dair_pll.dataset_management

Classes for generating and managing datasets for experiments.

Centers on the ExperimentDataManager type, which loads a set of trajectories saved to disk and transforms them for the various tasks encountered during an experiment.

class dair_pll.dataset_management.TrajectorySliceDataset(config)

Bases: Dataset

Dataset of trajectory slices for training.

Given a list of trajectories and a TrajectorySliceConfig, generates sets of (previous states, future states) transition pairs to be used with the training loss of an experiment.

Extends the torch.utils.data.Dataset type so that it can be managed during training by a torch.utils.data.DataLoader.

Parameters:

config (TrajectorySliceConfig) – configuration object for slice dataset.

config: dair_pll.data_config.TrajectorySliceConfig

Slice configuration describing durations and start index.

previous_states_slices: typing.List[torch.Tensor]

Initial conditions of duration self.config.t_history.

future_states_slices: typing.List[torch.Tensor]

Future targets of duration self.config.t_prediction.

add_slices_from_trajectory(trajectory)

Incorporate trajectory into dataset as a set of slices.

Parameters:

trajectory (Tensor) – (T, *) state trajectory.

Return type:

None
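
The slicing described above can be sketched as a sliding window over the time axis. This is a minimal illustration, not the library's implementation: slice_trajectory and its start parameter are hypothetical names, plain Python lists stand in for the (T, *) torch.Tensor, and the exact window boundaries used by dair_pll may differ. Only t_history and t_prediction are taken from the attribute descriptions above.

```python
from typing import List, Sequence, Tuple

State = Sequence[float]


def slice_trajectory(
    trajectory: Sequence[State],
    t_history: int,
    t_prediction: int,
    start: int = 0,
) -> Tuple[List[Sequence[State]], List[Sequence[State]]]:
    """Split one trajectory into aligned (previous, future) windows.

    `start` is a hypothetical stand-in for the configured start index.
    """
    previous_slices, future_slices = [], []
    # The first split point must leave room for a full history window.
    first_split = max(start, t_history)
    # The last split point must leave room for a full prediction window.
    last_split = len(trajectory) - t_prediction
    for split in range(first_split, last_split + 1):
        previous_slices.append(trajectory[split - t_history:split])
        future_slices.append(trajectory[split:split + t_prediction])
    return previous_slices, future_slices


# Six one-dimensional states, two steps of history, one step of prediction.
trajectory = [[float(t)] for t in range(6)]
previous, future = slice_trajectory(trajectory, t_history=2, t_prediction=1)
print(len(previous))  # number of transition pairs
```

Because the two lists are built in lockstep, previous[i] and future[i] always describe the same transition, which is the pairing a training loss would consume. A trajectory shorter than t_history + t_prediction simply contributes no slices.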

class dair_pll.dataset_management.TrajectorySet(slices, trajectories=<factory>, indices=<factory>)

Bases: object

Dataclass encapsulating the various transforms of a set of trajectories that are used during the training and evaluation process, including:

  • Slices for training;

  • Entire trajectories for evaluation; and

  • Indices associated with on-disk location for experiment resumption.

slices: dair_pll.dataset_management.TrajectorySliceDataset

Trajectories rendered as a dataset of time slices.

trajectories: typing.List[torch.Tensor]

Trajectories in their raw format.

indices: torch.Tensor

Indices associated with on-disk filenames.

add_trajectories(trajectory_list, indices)

Add new subset of trajectories to set.

Parameters:
  • trajectory_list (List[Tensor]) – List of new (T, *) state trajectories.

  • indices (Tensor) – indices associated with on-disk filenames.

Return type:

None
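
A rough sketch of how the three views might be kept in step when trajectories are added. SketchTrajectorySet is a hypothetical stand-in, not the dair_pll class: it uses plain lists instead of torch.Tensor, and it stores single-step transition pairs rather than slices governed by a TrajectorySliceConfig. The default_factory fields mirror the <factory> defaults shown in the signature above.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

Trajectory = List[List[float]]


@dataclass
class SketchTrajectorySet:
    """Hypothetical, list-based stand-in for a trajectory set."""
    # (previous state, next state) pairs; a simplification of the slice dataset.
    slices: List[Tuple[List[float], List[float]]]
    # default_factory gives every instance its own empty list.
    trajectories: List[Trajectory] = field(default_factory=list)
    indices: List[int] = field(default_factory=list)

    def add_trajectories(self, trajectory_list: List[Trajectory],
                         indices: List[int]) -> None:
        if len(trajectory_list) != len(indices):
            raise ValueError("need exactly one index per trajectory")
        for trajectory in trajectory_list:
            # Slices for training: consecutive single-step transitions.
            self.slices.extend(zip(trajectory[:-1], trajectory[1:]))
        # Entire trajectories for evaluation.
        self.trajectories.extend(trajectory_list)
        # On-disk indices for experiment resumption.
        self.indices.extend(indices)


train = SketchTrajectorySet(slices=[])
train.add_trajectories([[[0.0], [1.0], [2.0]], [[5.0], [6.0]]], indices=[0, 3])
print(len(train.slices), len(train.trajectories), train.indices)
```

Updating all three fields in one method keeps the slices, raw trajectories, and indices consistent with one another, so any of them can be used later without re-reading the files.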

class dair_pll.dataset_management.ExperimentDataManager(storage, config, initial_split=None, use_ground_truth=False)

Bases: object

Management object for maintaining training, validation, and testing data for an experiment.

Loads trajectories from the standard location associated with the provided storage directory, splits them into train/valid/test sets, and instantiates the transformations of each set as a TrajectorySet.

Parameters:
  • storage (str) – Storage directory to source trajectories from.

  • config (DataConfig) – Configuration object.

  • initial_split (Optional[Tuple[Tensor, Tensor, Tensor]]) – Optionally, lists of trajectory indices that should be sorted into (train, valid, test) sets from the beginning.

  • use_ground_truth (bool) – Whether trajectories should be sourced from ground-truth or learning data.

trajectory_dir: str

Directory in which trajectory files are stored.

config: dair_pll.data_config.DataConfig

Configuration for manipulating data.

train_set: dair_pll.dataset_management.TrajectorySet

Training trajectory set.

valid_set: dair_pll.dataset_management.TrajectorySet

Validation trajectory set.

test_set: dair_pll.dataset_management.TrajectorySet

Test trajectory set.

n_sorted: int

Number of files on disk split into (train, valid, test) sets so far.

trajectory_set_indices()

The sets of indices associated with the (train, valid, test) trajectories.

Return type:

Tuple[Tensor, Tensor, Tensor]

make_empty_trajectory_set()

Instantiates an empty TrajectorySet associated with the time slice configuration contained in config.

Return type:

TrajectorySet

extend_trajectory_sets(index_lists)

Supplement each of (train, valid, test) trajectory sets with provided trajectories, listed by their on-disk indices.

Parameters:

index_lists (Tuple[Tensor, Tensor, Tensor]) – Lists of trajectory indices for each set.

Return type:

None

get_updated_trajectory_sets()

Returns an up-to-date partition of trajectories on disk.

Checks if some trajectories on disk have yet to be sorted, and supplements the (train, valid, test) sets with these additional trajectories before returning the updated sets.

Return type:

Tuple[TrajectorySet, TrajectorySet, TrajectorySet]

Returns:
  • Training set.

  • Validation set.

  • Test set.
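
The incremental behavior of get_updated_trajectory_sets can be illustrated with a small sketch. It is not the dair_pll implementation: assign_unsorted and its fractions argument are hypothetical, the split rule (fill train, then valid, then test, in index order) is an assumption chosen for simplicity, and plain integers stand in for index tensors. It only shows how a counter such as n_sorted lets already-sorted trajectories keep their assignment while newly written files are distributed.

```python
from typing import List, Tuple


def assign_unsorted(
    n_sorted: int,
    n_on_disk: int,
    fractions: Tuple[float, float] = (0.5, 0.25),
) -> Tuple[List[int], List[int], List[int]]:
    """Assign not-yet-sorted file indices to (train, valid, test).

    `fractions` holds hypothetical train and valid shares; whatever
    remains goes to test.
    """
    new_indices = list(range(n_sorted, n_on_disk))
    # Truncate toward zero so the three parts never exceed the total.
    n_train = int(fractions[0] * len(new_indices))
    n_valid = int(fractions[1] * len(new_indices))
    train = new_indices[:n_train]
    valid = new_indices[n_train:n_train + n_valid]
    test = new_indices[n_train + n_valid:]
    return train, valid, test


# Four files were sorted earlier; eight are now on disk.
train, valid, test = assign_unsorted(n_sorted=4, n_on_disk=8)
n_sorted = 4 + len(train) + len(valid) + len(test)
print(train, valid, test, n_sorted)
```

In terms of the API above, the three returned index lists play the role of the index_lists argument to extend_trajectory_sets, and calling the function again with the updated n_sorted and an unchanged file count yields three empty lists.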