dair_pll.experiment
Defines interfaces for various learning experiments to be run.
Currently supported experiment types include:
SupervisedLearningExperiment
: An experiment where a System
is learned to mimic a dataset of trajectories.
- class dair_pll.experiment.TrainingState(trajectory_set_split_indices, best_learned_system_state, current_learned_system_state, optimizer_state, epoch=1, epochs_since_best=0, best_valid_loss=<factory>, wandb_run_id=None, finished_training=False)[source]
Bases:
object
Dataclass to store a complete summary of the state of the training process.
-
trajectory_set_split_indices:
typing.Tuple
[torch.Tensor
,torch.Tensor
,torch.Tensor
] Which trajectory indices are in train/valid/test sets.
-
best_learned_system_state:
dict
State of learned system when it had the best validation loss so far.
-
optimizer_state:
dict
Current state of training
torch.optim.Optimizer
.
-
best_valid_loss:
torch.Tensor
Value of best validation loss so far.
-
wandb_run_id:
typing.Optional
[str
] = None If using W&B, the ID of the run associated with this experiment.
- dair_pll.experiment.EpochCallbackCallable
Type hint for extra callback to be called on each epoch of training.
- Parameters:
epoch – Current epoch.
learned_system – Partially-trained learned system.
train_loss – Current epoch’s average training loss.
best_valid_loss – Best validation loss so far.
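As an illustration of the four-argument signature above, the following is a minimal sketch of a custom callback that records progress instead of printing it. The name `recording_epoch_callback` and the module-level `history` list are hypothetical and not part of `dair_pll`; the learned system is typed as `Any` so the sketch runs without the library installed.

```python
from typing import Any, List, Tuple

# Hypothetical in-memory log; not part of dair_pll.
history: List[Tuple[int, float, float]] = []


def recording_epoch_callback(epoch: int, learned_system: Any,
                             train_loss: Any, best_valid_loss: Any) -> None:
    """Record (epoch, train_loss, best_valid_loss) for later inspection."""
    # float() accepts Python numbers as well as single-element tensors.
    history.append((epoch, float(train_loss), float(best_valid_loss)))


# Stand-in calls; during training the experiment would supply these values.
recording_epoch_callback(1, None, 0.52, 0.61)
recording_epoch_callback(2, None, 0.40, 0.47)
```

A function of this shape would be passed as the `epoch_callback` argument of `train()` or `generate_results()`.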
- dair_pll.experiment.LossCallbackCallable
Callback to evaluate loss on batch of trajectory slices.
By default, set to prediction loss (
SupervisedLearningExperiment.prediction_loss()
).
- Parameters:
x_past –
(*,t_history,space.n_x)
previous states in slice.
x_future –
(*,t_prediction,space.n_x)
future states in slice.
system – system on which to evaluate loss.
keep_batch – whether to keep the batch dimension rather than collapse the batch into a single scalar.
- Returns:
(*,)
or scalar loss.
- dair_pll.experiment.default_epoch_callback(epoch, _learned_system, train_loss, best_valid_loss)[source]
Default
EpochCallbackCallable
which prints epoch, training loss, and best validation loss so far.
- Return type:
- class dair_pll.experiment.SupervisedLearningExperiment(config)[source]
Bases:
ABC
Supervised learning experiment.
Implements the training and evaluation processes for a supervised learning experiment, where a
System
is learned to capture a dataset of trajectories.
The dataset of trajectories is encapsulated in an
ExperimentDataManager
object. This dataset is either stored to disk by the user or generated from the experiment’s base system.
The base system is a
System
with the same
StateSpace
as the system to be learned.
Training is completed via a PyTorch
Optimizer
.
The training process keeps track of various statistics about the learning process, and optionally logs the learned system’s
SystemSummary
to Weights & Biases on each epoch.
-
wandb_manager:
typing.Optional
[dair_pll.wandb_manager.WeightsAndBiasesManager
] Optional Weights & Biases interface.
-
config:
dair_pll.experiment_config.SupervisedLearningExperimentConfig
Configuration of the experiment.
-
space:
dair_pll.state_space.StateSpace
State space of experiment, inferred from base system.
-
loss_callback:
typing.Optional
[typing.Callable
[[Tensor
,Tensor
,System
,bool
],Tensor
]] Callback function for loss, defaults to prediction loss.
-
learning_data_manager:
typing.Optional
[dair_pll.dataset_management.ExperimentDataManager
] Manager of trajectory data used in learning process.
- abstract get_base_system()[source]
Abstract callback function to construct base system from system config.
- Return type:
- Returns:
Experiment’s base system.
- get_oracle_system()[source]
Callback function to construct oracle system for experiment.
Conceptually, the oracle system is an ideal system to compare the learned system against. By default, the oracle system is simply the base system. However, in some scenarios, a different type of oracle is appropriate. For example, if the learned system is recurrent, the oracle system might most appropriately take a recurrent slice of initial states, process them with a Kalman Filter for the base system, and then predict the future.
- Return type:
- Returns:
Experiment’s oracle system.
- abstract get_learned_system(train_states)[source]
Abstract callback function to construct learnable system for experiment.
Optionally, learned system can be initialized to depend on the training dataset.
- trajectory_predict(x, system, do_detach=False)[source]
Predict from full lists of trajectories.
Preloads initial conditions from the first
t_skip + 1
elements of each trajectory.
- Parameters:
- Return type:
- Returns:
List of
(T - t_skip - 1, space.n_x)
predicted trajectories.
List of
(T - t_skip - 1, space.n_x)
target trajectories.
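The shapes above follow from how each trajectory is split: the first `t_skip + 1` states seed the prediction and the remaining `T - t_skip - 1` states are the targets. The sketch below shows only that slicing, using plain Python lists in place of tensors; `split_trajectory` is a hypothetical helper and not a `dair_pll` function.

```python
from typing import List, Sequence, Tuple

State = Sequence[float]


def split_trajectory(trajectory: List[State],
                     t_skip: int) -> Tuple[List[State], List[State]]:
    """Split one trajectory into initial conditions and prediction targets."""
    if len(trajectory) <= t_skip + 1:
        raise ValueError("trajectory too short for the requested t_skip")
    initial = trajectory[:t_skip + 1]   # first t_skip + 1 states
    targets = trajectory[t_skip + 1:]   # remaining T - t_skip - 1 states
    return initial, targets


trajectory = [[float(t), 0.0] for t in range(6)]  # T = 6, n_x = 2
initial, targets = split_trajectory(trajectory, t_skip=1)
```

With `T = 6` and `t_skip = 1`, the split yields 2 initial states and 4 target states.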
- prediction_loss(x_past, x_future, system, keep_batch=False)[source]
Default
LossCallbackCallable
which evaluates to the system’s squared \(l_2\) prediction error on the batch:
\[\mathcal{L}(x_{p,i,\cdot}, x_{f,i,\cdot}) = \sum_{j} ||\hat x_{f, i,j} - x_{f,i,j}||^2,\]
where \(x_{p,i,\cdot}, x_{f,i,\cdot}\) are the \(i\)th elements of the past and future batches, and \(\hat x_{f,i,j}\) is the \(j\)-step forward prediction of the model from the past batch.
See
LossCallbackCallable
for additional type signature info.
- Return type:
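To make the formula concrete, here is a small worked sketch that evaluates the per-element sum of squared errors on nested Python lists rather than tensors. It is not the library's implementation: the function name is hypothetical, the predictions are supplied directly instead of being rolled out from a system, and collapsing the batch by averaging when `keep_batch` is false is an assumption, since the reduction is not specified here.

```python
from typing import List, Union

Batch = List[List[List[float]]]  # (batch, t_prediction, n_x)


def squared_l2_prediction_loss(x_hat_future: Batch, x_future: Batch,
                               keep_batch: bool = False
                               ) -> Union[List[float], float]:
    """Sum squared errors over prediction steps j for each batch element i."""
    per_element = []
    for pred_i, true_i in zip(x_hat_future, x_future):
        loss_i = sum((p - t) ** 2
                     for pred_ij, true_ij in zip(pred_i, true_i)
                     for p, t in zip(pred_ij, true_ij))
        per_element.append(loss_i)
    if keep_batch:
        return per_element
    # Assumption: the batch is collapsed by averaging.
    return sum(per_element) / len(per_element)


x_future = [[[1.0, 0.0], [2.0, 0.0]], [[0.0, 1.0], [0.0, 2.0]]]
x_hat = [[[1.0, 0.0], [2.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]
per_element = squared_l2_prediction_loss(x_hat, x_future, keep_batch=True)
scalar = squared_l2_prediction_loss(x_hat, x_future)
```

Here the first batch element contributes 0 + 1 and the second contributes 1 + 4, so the per-element losses are 1.0 and 5.0.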
- batch_loss(x_past, x_future, system, keep_batch=False)[source]
Runs
loss_callback
(a LossCallbackCallable
) on the given batch.
- Return type:
- train_epoch(data, system, optimizer=None)[source]
Train learned model for a single epoch. Takes gradient steps in the learned parameters if
optimizer
is provided.
- Parameters:
data (
DataLoader
) – Training dataset.
system (
System
) – System to be trained.
optimizer (
Optional
[Optimizer
]) – Optimizer which trains system.
- Return type:
- Returns:
Scalar average training loss observed during epoch.
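The optional-optimizer behaviour can be sketched with a toy model in plain Python. Everything below is a hypothetical stand-in: `ScalarModel`, `sgd_step`, and `run_epoch` are not `dair_pll` or PyTorch names, the gradient is computed by hand, and averaging the loss over batches is an assumption about how the epoch average is formed.

```python
from typing import Callable, Iterable, List, Optional, Tuple

Batch = Tuple[List[float], List[float]]  # (inputs, targets)


class ScalarModel:
    """Toy model y = w * x, standing in for a learnable System."""

    def __init__(self, w: float = 0.0) -> None:
        self.w = w
        self.grad = 0.0

    def loss(self, batch: Batch) -> float:
        """Mean squared error on the batch; also stores d(loss)/dw."""
        xs, ys = batch
        n = len(xs)
        self.grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        return sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n


def sgd_step(model: ScalarModel, lr: float = 0.1) -> None:
    """Hand-written gradient step, standing in for an optimizer."""
    model.w -= lr * model.grad


def run_epoch(data: Iterable[Batch], model: ScalarModel,
              optimizer: Optional[Callable[[ScalarModel], None]] = None
              ) -> float:
    """Return the average batch loss; update the model only if asked to."""
    losses = []
    for batch in data:
        losses.append(model.loss(batch))
        if optimizer is not None:
            optimizer(model)
    return sum(losses) / len(losses)


data = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
model = ScalarModel()
eval_loss = run_epoch(data, model)             # no optimizer: w unchanged
w_after_eval = model.w
train_loss = run_epoch(data, model, sgd_step)  # optimizer given: w moves
```

Calling the same routine without an optimizer gives a pure evaluation pass, which is one plausible reason for making the argument optional.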
- base_and_learned_comparison_summary(statistics, learned_system)[source]
Extracts a
SystemSummary
that compares the base system to the learned system.
- Parameters:
- Return type:
- Returns:
Summary of comparison between systems.
- write_to_wandb(epoch, learned_system, statistics)[source]
Extracts and writes summary of training progress to Weights & Biases.
- per_epoch_evaluation(epoch, learned_system, train_loss, training_duration)[source]
Evaluates and logs training progress at end of an epoch.
Runs evaluation on full slice datasets, as well as a handful of trajectories.
Optionally logs the results to Weights & Biases via
write_to_wandb()
.
- setup_training()[source]
Sets up initial condition for training process.
Attempts to load initial condition from disk as a
TrainingState
. Otherwise, a fresh training process is started.
- Return type:
- Returns:
Initial learned system.
PyTorch optimizer.
Current state of training process.
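The load-or-start-fresh pattern described above can be sketched as follows. This is an illustration only: `ToyTrainingState` is a simplified, hypothetical stand-in that keeps three of the `TrainingState` fields, and storing the checkpoint as JSON at a temporary path is an assumption, since the on-disk format is not described here.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Tuple


@dataclass
class ToyTrainingState:
    """Simplified, hypothetical stand-in for TrainingState."""
    epoch: int = 1
    epochs_since_best: int = 0
    finished_training: bool = False


def load_or_init(checkpoint: Path) -> Tuple[ToyTrainingState, bool]:
    """Return (state, resumed); resumed is True if loaded from disk."""
    if checkpoint.exists():
        return ToyTrainingState(**json.loads(checkpoint.read_text())), True
    return ToyTrainingState(), False


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "checkpoint.json"
    fresh, resumed_first = load_or_init(path)    # nothing on disk yet
    fresh.epoch = 5                              # pretend four epochs ran
    path.write_text(json.dumps(asdict(fresh)))   # save a checkpoint
    loaded, resumed_second = load_or_init(path)  # resumes from epoch 5
```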
- train(epoch_callback=<function 'default_epoch_callback'>)[source]
Run training process for experiment.
Terminates training with early stopping, parameters for which are set in
config
.
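Early stopping of this kind is commonly implemented with a patience counter. The sketch below assumes that form: `patience` is a hypothetical parameter name, the validation losses are supplied as a precomputed list, and the variables `best_valid_loss` and `epochs_since_best` simply mirror the `TrainingState` field names rather than reproduce the library's logic.

```python
from typing import List, Tuple


def early_stopping_epochs(valid_losses: List[float],
                          patience: int) -> Tuple[int, float]:
    """Return (last epoch run, best validation loss seen)."""
    best_valid_loss = float("inf")
    epochs_since_best = 0
    epoch = 0
    for epoch, valid_loss in enumerate(valid_losses, start=1):
        if valid_loss < best_valid_loss:
            # New best: remember it and reset the counter.
            best_valid_loss = valid_loss
            epochs_since_best = 0
        else:
            epochs_since_best += 1
        if epochs_since_best >= patience:
            break  # no improvement for `patience` consecutive epochs
    return epoch, best_valid_loss


stopped_at, best = early_stopping_epochs(
    [0.9, 0.7, 0.8, 0.75, 0.71, 0.6], patience=3)
```

In this run the loss last improves at epoch 2, so training stops after epoch 5 and never reaches the lower loss at epoch 6, which shows the trade-off a small patience value makes.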
- evaluate_systems_on_sets(systems, sets)[source]
Evaluate given systems on trajectory sets.
Builds a “statistics” dictionary containing a thorough evaluation of each system on each set, including the following:
Single step and trajectory prediction losses.
Squared norms of velocity and delta-velocity (for normalization).
Sample target and prediction trajectories.
Auxiliary trajectory comparisons defined in
dair_pll.state_space.StateSpace.auxiliary_comparisons()
Summary statistics of the above where applicable.
- Parameters:
- Return type:
- Returns:
Statistics dictionary.
Warning
Currently assumes prediction horizon of 1.
- generate_results(epoch_callback=<function 'default_epoch_callback'>)[source]
Get the final learned model and results/statistics of experiment. Along with the model corresponding to best validation loss, this will return previously saved results on disk if they already exist, or run the experiment to generate them if they don’t.