dair_pll.dataset_management
Classes for generating and managing datasets for experiments.
Centers around the ExperimentDataManager type, which transforms a
set of trajectories saved to disk for various tasks encountered during an
experiment.
- class dair_pll.dataset_management.TrajectorySliceDataset(config)[source]
Bases:
Dataset
Dataset of trajectory slices for training.
Given a list of trajectories and a
TrajectorySliceConfig, generates sets of (previous states, future states) transition pairs to be used with the training loss of an experiment.
Extends the
torch.utils.data.Dataset type in order to be managed in the training process with a torch.utils.data.DataLoader.
- Parameters:
config (
TrajectorySliceConfig) – configuration object for slice dataset.
-
config:
dair_pll.data_config.TrajectorySliceConfig Slice configuration describing durations and start index.
-
previous_states_slices:
typing.List[torch.Tensor] Initial conditions of duration
self.config.t_history.
-
future_states_slices:
typing.List[torch.Tensor] Future targets of duration
self.config.t_prediction.
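The slicing described above can be illustrated with a minimal sketch. It is not the library's implementation: nested Python lists stand in for torch.Tensor trajectories, `slice_trajectory` is a hypothetical helper, and the stride-one sliding window is an assumption. Only the names `t_history` and `t_prediction` come from the documentation above; the start index mentioned for the configuration is not modeled.

```python
from typing import List, Tuple

State = List[float]        # stand-in for one state vector
Trajectory = List[State]   # stand-in for a trajectory tensor


def slice_trajectory(trajectory: Trajectory, t_history: int,
                     t_prediction: int
                     ) -> Tuple[List[Trajectory], List[Trajectory]]:
    """Split one trajectory into (previous states, future states) pairs.

    Hypothetical helper: assumes a window that advances one step at a time.
    """
    previous_slices: List[Trajectory] = []
    future_slices: List[Trajectory] = []
    window = t_history + t_prediction
    # Slide a window of length t_history + t_prediction along the trajectory.
    for start in range(len(trajectory) - window + 1):
        split = start + t_history
        previous_slices.append(trajectory[start:split])
        future_slices.append(trajectory[split:start + window])
    return previous_slices, future_slices


# A 5-step trajectory of 1-dimensional states.
trajectory = [[0.0], [1.0], [2.0], [3.0], [4.0]]
previous, future = slice_trajectory(trajectory, t_history=2, t_prediction=1)
```

Under these assumptions, each entry of `previous` holds `t_history` consecutive states and the matching entry of `future` holds the `t_prediction` states that immediately follow, which is the pairing a training loss would consume.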
- class dair_pll.dataset_management.TrajectorySet(slices, trajectories=<factory>, indices=<factory>)[source]
Bases:
object
Dataclass encapsulating the various transforms of a set of trajectories that are used during the training and evaluation process, including:
Slices for training;
Entire trajectories for evaluation; and
Indices associated with on-disk location for experiment resumption.
-
slices:
dair_pll.dataset_management.TrajectorySliceDataset Trajectories rendered as a dataset of time slices.
-
trajectories:
typing.List[torch.Tensor] Trajectories in their raw format.
-
indices:
torch.Tensor Indices associated with on-disk filenames.
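The `<factory>` defaults in the signature above are how a dataclass field declared with `default_factory` is rendered. A minimal sketch of that pattern follows; `TrajectorySetSketch` is a hypothetical stand-in, and plain lists replace the TrajectorySliceDataset and torch.Tensor types so the example has no dependencies.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TrajectorySetSketch:
    """Hypothetical stand-in mirroring the three documented fields."""
    # Stand-in for the TrajectorySliceDataset of time slices.
    slices: List[tuple]
    # Each instance gets its own fresh list rather than a shared default.
    trajectories: List[list] = field(default_factory=list)
    # Stand-in for the tensor of on-disk indices.
    indices: List[int] = field(default_factory=list)


first = TrajectorySetSketch(slices=[])
second = TrajectorySetSketch(slices=[])

# Adding a trajectory to one set leaves the other untouched.
first.trajectories.append([[0.0], [1.0]])
first.indices.append(7)
```

A factory is used instead of a literal default because a mutable default would otherwise be shared across instances; here `second` stays empty after `first` is extended.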
- class dair_pll.dataset_management.ExperimentDataManager(storage, config, initial_split=None, use_ground_truth=False)[source]
Bases:
object
Management object for maintaining training, validation, and testing data for an experiment.
Loads trajectories stored in the standard location associated with the provided storage directory; splits them into train/valid/test sets; and instantiates transformations for each set of data as a
TrajectorySet.
- Parameters:
storage (
str) – Storage directory to source trajectories from.
config (
DataConfig) – Configuration object.
initial_split (
Optional[Tuple[Tensor,Tensor,Tensor]]) – Optionally, lists of trajectory indices that should be sorted into (train, valid, test) sets from the beginning.
use_ground_truth (
bool) – Whether trajectories should be sourced from ground-truth or learning data.
-
config:
dair_pll.data_config.DataConfig Configuration for manipulating data.
-
train_set:
dair_pll.dataset_management.TrajectorySet Training trajectory set.
-
valid_set:
dair_pll.dataset_management.TrajectorySet Validation trajectory set.
-
test_set:
dair_pll.dataset_management.TrajectorySet Test trajectory set.
- trajectory_set_indices()[source]
The sets of indices associated with the (train, valid, test) trajectories.
- make_empty_trajectory_set()[source]
Instantiates an empty
TrajectorySet associated with the time slice configuration contained in config.
- Return type:
- extend_trajectory_sets(index_lists)[source]
Supplement each of (train, valid, test) trajectory sets with provided trajectories, listed by their on-disk indices.
- get_updated_trajectory_sets()[source]
Returns an up-to-date partition of trajectories on disk.
Checks if some trajectories on disk have yet to be sorted, and supplements the (train, valid, test) sets with these additional trajectories before returning the updated sets.
- Return type:
- Returns:
Training set. Validation set. Test set.
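The update step described for get_updated_trajectory_sets() can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `assign_unsorted` and its `fractions` argument are hypothetical, plain integer lists stand in for index tensors, and the weighted random assignment is one possible policy since the documentation does not say how unsorted trajectories are distributed.

```python
import random
from typing import List, Sequence, Set, Tuple

Split = Tuple[List[int], List[int], List[int]]


def assign_unsorted(on_disk: Sequence[int], split: Split,
                    fractions: Tuple[float, float, float],
                    rng: random.Random) -> Split:
    """Sort not-yet-assigned indices into (train, valid, test) in place.

    Hypothetical helper: the sampling policy is an assumption.
    """
    already_sorted: Set[int] = set().union(*split)
    unsorted = [i for i in on_disk if i not in already_sorted]
    for index in unsorted:
        # Pick a destination set for each new trajectory, weighted by fractions.
        target = rng.choices(split, weights=fractions, k=1)[0]
        target.append(index)
    return split


# Indices 0-3 were sorted earlier; indices 4 and 5 have since appeared on disk.
split = ([0, 1], [2], [3])
assign_unsorted(range(6), split, fractions=(0.5, 0.25, 0.25),
                rng=random.Random(0))
sizes_after_first_call = [len(s) for s in split]

# A second call with no new trajectories on disk changes nothing.
assign_unsorted(range(6), split, fractions=(0.5, 0.25, 0.25),
                rng=random.Random(1))
```

Previously sorted indices keep their assignment, so resuming an experiment would not move a trajectory between sets; only trajectories that are new on disk are distributed.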