dair_pll.experiment

Defines interfaces for various learning experiments to be run.

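Currently supported experiment types include:

  • SupervisedLearningExperiment: an experiment in which a System is learned to capture a dataset of trajectories.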

class dair_pll.experiment.TrainingState(trajectory_set_split_indices, best_learned_system_state, current_learned_system_state, optimizer_state, epoch=1, epochs_since_best=0, best_valid_loss=<factory>, wandb_run_id=None, finished_training=False)[source]

Bases: object

Dataclass to store a complete summary of the state of the training process.

trajectory_set_split_indices: typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Which trajectory indices are in train/valid/test sets.

best_learned_system_state: dict

State of learned system when it had the best validation loss so far.

current_learned_system_state: dict

Current state of learned system.

optimizer_state: dict

Current state of the training torch.optim.Optimizer.

epoch: int = 1

Current epoch.

epochs_since_best: int = 0

Number of epochs since best validation loss so far was achieved.

best_valid_loss: torch.Tensor

Value of best validation loss so far.

wandb_run_id: typing.Optional[str] = None

If using W&B, the ID of the run associated with this experiment.

finished_training: bool = False

Whether training has finished.
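The <factory> default shown for best_valid_loss in the class signature is how Python's dataclasses module displays a field whose default is produced by default_factory. The sketch below is a dependency-free stand-in (the class name TrainingStateSketch is hypothetical) that uses plain lists and floats in place of tensors; starting the best validation loss at positive infinity is an assumption, not something this page states.

```python
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Tuple


@dataclass
class TrainingStateSketch:
    """Hypothetical stand-in for TrainingState using plain Python types."""

    # Train/valid/test trajectory indices (lists stand in for Tensors).
    trajectory_set_split_indices: Tuple[List[int], List[int], List[int]]
    best_learned_system_state: dict
    current_learned_system_state: dict
    optimizer_state: dict
    epoch: int = 1
    epochs_since_best: int = 0
    # Defaults built by a factory appear as "<factory>" in the generated
    # __init__ signature. Starting at +inf is an assumption of this sketch.
    best_valid_loss: float = field(default_factory=lambda: float("inf"))
    wandb_run_id: Optional[str] = None
    finished_training: bool = False


state = TrainingStateSketch(([0, 1, 2], [3], [4]), {}, {}, {})
checkpoint = asdict(state)  # plain dict, e.g. for serialization to disk
```

Storing every counter in one object like this is what makes it possible to resume an interrupted run from a single saved checkpoint.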

dair_pll.experiment.EpochCallbackCallable

Type hint for extra callback to be called on each epoch of training.

Parameters:
  • epoch – Current epoch.

  • learned_system – Partially-trained learned system.

  • train_loss – Current epoch’s average training loss.

  • best_valid_loss – Best validation loss so far.

alias of Callable[[int, System, Tensor, Tensor], None]
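Any function that accepts these four positional arguments can serve as an epoch callback. The hypothetical make_history_callback below records the per-epoch losses in a list instead of printing them; float() converts a 0-dimensional tensor (or a plain number) to a Python float.

```python
from typing import Any, Callable, List, Tuple

History = List[Tuple[int, float, float]]


def make_history_callback(history: History) -> Callable[[int, Any, Any, Any], None]:
    """Hypothetical factory for an EpochCallbackCallable that logs to a list."""

    def callback(epoch: int, learned_system: Any, train_loss: Any,
                 best_valid_loss: Any) -> None:
        # learned_system is unused here, but must still be accepted positionally.
        history.append((epoch, float(train_loss), float(best_valid_loss)))

    return callback
```

Assuming experiment is an instance of a concrete SupervisedLearningExperiment subclass, this would be passed as experiment.train(epoch_callback=make_history_callback(history)).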

dair_pll.experiment.LossCallbackCallable

Callback to evaluate loss on a batch of trajectory slices.

By default, set to the prediction loss (SupervisedLearningExperiment.prediction_loss()).

Parameters:
  • x_past – (*, t_history, space.n_x) previous states in slice.

  • x_future – (*, t_prediction, space.n_x) future states in slice.

  • system – System on which to evaluate loss.

  • keep_batch – Whether to keep the batch dimension, returning one loss per batch element, instead of collapsing the batch into a single scalar.

Returns:

(*,) or scalar loss.

alias of Callable[[Tensor, Tensor, System, bool], Tensor]
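A custom loss only needs to honour this signature and the keep_batch contract. The sketch below is dependency-free: batches are nested lists with a single batch dimension, and system is assumed to be a plain callable mapping x_past to predicted future states, which simplifies the real System interface. The loss itself (squared error on the final predicted state only) and the mean reduction are illustrative choices.

```python
from typing import Callable, List, Union

State = List[float]
Batch = List[List[State]]  # (batch, t, n_x); one batch dimension for simplicity


def final_state_loss(x_past: Batch, x_future: Batch,
                     system: Callable[[Batch], Batch],
                     keep_batch: bool = False) -> Union[List[float], float]:
    """Hypothetical LossCallbackCallable: squared error on the last state only."""
    # Assumption: `system` maps x_past to predictions of shape
    # (batch, t_prediction, n_x), standing in for a real System rollout.
    x_pred = system(x_past)
    per_sample = [
        sum((p - f) ** 2 for p, f in zip(pred[-1], fut[-1]))
        for pred, fut in zip(x_pred, x_future)
    ]
    if keep_batch:
        return per_sample  # shape (batch,)
    return sum(per_sample) / len(per_sample)  # scalar; mean is an assumption


def hold_last(xp: Batch) -> Batch:
    """Toy "system" predicting that the last observed state persists."""
    return [[history[-1]] for history in xp]


x_past = [[[0.0], [1.0]], [[2.0], [3.0]]]
x_future = [[[2.0]], [[3.0]]]
```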

dair_pll.experiment.default_epoch_callback(epoch, _learned_system, train_loss, best_valid_loss)[source]

Default EpochCallbackCallable which prints epoch, training loss, and best validation loss so far.

Return type:

None

class dair_pll.experiment.SupervisedLearningExperiment(config)[source]

Bases: ABC

Supervised learning experiment.

Implements the training and evaluation processes for a supervised learning experiment, where a System is learned to capture a dataset of trajectories.

The dataset of trajectories is encapsulated in an ExperimentDataManager object. This dataset is either stored to disk by the user, or alternatively is generated from the experiment’s base system.

The base system is a System with the same StateSpace as the system to be learned.

Training is performed with a PyTorch Optimizer.

The training process keeps track of various statistics about the learning process, and optionally logs the learned system’s SystemSummary to Weights & Biases (W&B) on each epoch.

wandb_manager: typing.Optional[dair_pll.wandb_manager.WeightsAndBiasesManager]

Optional Weights & Biases (W&B) interface.

config: dair_pll.experiment_config.SupervisedLearningExperimentConfig

Configuration of the experiment.

space: dair_pll.state_space.StateSpace

State space of experiment, inferred from base system.

loss_callback: typing.Optional[typing.Callable[[Tensor, Tensor, System, bool], Tensor]]

Callback function for loss; defaults to the prediction loss.

learning_data_manager: typing.Optional[dair_pll.dataset_management.ExperimentDataManager]

Manager of trajectory data used in learning process.

abstract get_base_system()[source]

Abstract callback function to construct the base system from the system config.

Return type:

System

Returns:

Experiment’s base system.

get_oracle_system()[source]

Callback function to construct the oracle system for the experiment.

Conceptually, the oracle system is an ideal system to compare the learned system against. By default, the oracle system is simply the base system. However, in some scenarios, a different type of oracle is appropriate. For example, if the learned system is recurrent, the oracle system might most appropriately take a recurrent slice of initial states, process them with a Kalman Filter for the base system, and then predict the future.

Return type:

System

Returns:

Experiment’s oracle system.

abstract get_learned_system(train_states)[source]

Abstract callback function to construct the learnable system for the experiment.

Optionally, the learned system’s initialization may depend on the training dataset.

Parameters:

train_states (Tensor) – (*, space.n_x) batch of all states in training set.

Return type:

System

Returns:

Experiment’s learnable system.
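The three system-construction callbacks follow a template-method pattern: subclasses must implement get_base_system() and get_learned_system(), while get_oracle_system() falls back to the base system. The plain-Python sketch below mirrors that structure with hypothetical class names and placeholder return values rather than real System objects.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class ExperimentSketch(ABC):
    """Hypothetical mirror of the SupervisedLearningExperiment callbacks."""

    @abstractmethod
    def get_base_system(self) -> Any:
        """Subclasses must build the base system (e.g. from a config)."""

    def get_oracle_system(self) -> Any:
        # Default behaviour: the oracle is simply the base system.
        return self.get_base_system()

    @abstractmethod
    def get_learned_system(self, train_states: List[List[float]]) -> Any:
        """Subclasses build the learnable system, optionally from train data."""


class MeanInitExperiment(ExperimentSketch):
    """Hypothetical subclass whose learned system starts at the data mean."""

    def get_base_system(self) -> str:
        return "base-system"  # placeholder for a real System object

    def get_learned_system(self, train_states: List[List[float]]) -> dict:
        # train_states has shape (N, n_x); average each state coordinate.
        n = len(train_states)
        mean = [sum(col) / n for col in zip(*train_states)]
        return {"init_mean": mean}  # placeholder for a learnable System
```

Because ExperimentSketch has abstract methods, it cannot be instantiated directly; only a subclass that implements both required callbacks can.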

get_optimizer(learned_system)[source]

Constructs optimizer for experiment.

Parameters:

learned_system (System) – System to be trained.

Return type:

Optimizer

Returns:

Optimizer for training.

batch_predict(x_past, system)[source]

Predict forward in time from initial conditions.

Parameters:
  • x_past (Tensor) – (*, t_history, space.n_x) batch of initial states.

  • system (System) – System to run prediction on.

Return type:

Tensor

Returns:

(*, t_prediction, space.n_x) batch of predicted future states.
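One common way to predict forward from a history window is an autoregressive rollout that feeds each prediction back in as input. This page does not say how batch_predict() is implemented, so the sketch below is only an illustration: step is a hypothetical one-step model mapping a history window to the next state, and nested lists stand in for tensors.

```python
from typing import Callable, List

State = List[float]


def rollout(x_past: List[List[State]],
            step: Callable[[List[State]], State],
            t_prediction: int) -> List[List[State]]:
    """Hypothetical autoregressive rollout over a batch of state histories.

    x_past has shape (batch, t_history, n_x); the result has shape
    (batch, t_prediction, n_x).
    """
    predictions = []
    for history in x_past:
        window = list(history)
        future = []
        for _ in range(t_prediction):
            x_next = step(window)
            future.append(x_next)
            # Slide the window: drop the oldest state, append the newest.
            window = window[1:] + [x_next]
        predictions.append(future)
    return predictions


def constant_velocity(window: List[State]) -> State:
    """Toy 1-D model: extrapolate linearly from the last two states."""
    return [2 * window[-1][0] - window[-2][0]]
```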

trajectory_predict(x, system, do_detach=False)[source]

Predict from full lists of trajectories.

Preloads initial conditions from the first t_skip + 1 elements of each trajectory.

Parameters:
  • x (List[Tensor]) – List of (T, space.n_x) trajectories.

  • system (System) – System to run prediction on.

  • do_detach (bool) – Whether to detach each prediction from the computation graph; useful for memory management for large groups of trajectories.

Return type:

Tuple[List[Tensor], List[Tensor]]

Returns:

List of (T - t_skip - 1, space.n_x) predicted trajectories.

List of (T - t_skip - 1, space.n_x) target trajectories.
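These shapes follow from simple index arithmetic: with t_skip + 1 states used as initial conditions, a trajectory of length T leaves T - t_skip - 1 target states. The hypothetical helper below makes that split explicit for a trajectory stored as a list of states.

```python
from typing import List, Tuple

Trajectory = List[List[float]]  # (T, n_x)


def split_trajectory(x: Trajectory, t_skip: int) -> Tuple[Trajectory, Trajectory]:
    """Hypothetical helper: initial-condition window vs. prediction targets."""
    n_init = t_skip + 1
    # First t_skip + 1 states seed the prediction; the rest are targets.
    return x[:n_init], x[n_init:]
```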

prediction_loss(x_past, x_future, system, keep_batch=False)[source]

Default LossCallbackCallable, which evaluates the system’s squared \(l_2\) prediction error on a batch:

\[\mathcal{L}(x_{p,i,\cdot}, x_{f,i,\cdot}) = \sum_{j} ||\hat x_{f, i,j} - x_{f,i,j}||^2,\]

where \(x_{p,i,\cdot}, x_{f,i,\cdot}\) are the \(i\)th elements of the past and future batches; and \(\hat x_{f,i,j}\) is the \(j\)-step forward prediction of the model from the past batch.
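The formula can be checked by hand on a small batch. The dependency-free sketch below takes already-computed predictions \(\hat x_f\) as nested lists of shape (batch, t_prediction, n_x) instead of a System; averaging over the batch when keep_batch=False is an assumption, since this page does not state the reduction.

```python
from typing import List, Union

Batch = List[List[List[float]]]  # (batch, t_prediction, n_x)


def l2_prediction_loss(x_hat: Batch, x_future: Batch,
                       keep_batch: bool = False) -> Union[List[float], float]:
    """Per batch element i: sum over steps j of ||x_hat[i][j] - x_future[i][j]||^2."""
    per_sample = []
    for pred_i, true_i in zip(x_hat, x_future):
        total = 0.0
        for pred_ij, true_ij in zip(pred_i, true_i):
            total += sum((p - t) ** 2 for p, t in zip(pred_ij, true_ij))
        per_sample.append(total)
    if keep_batch:
        return per_sample
    # Scalar reduction: mean over the batch (an assumption of this sketch).
    return sum(per_sample) / len(per_sample)


x_hat = [[[1.0, 0.0], [2.0, 2.0]], [[0.0, 0.0], [0.0, 0.0]]]
x_true = [[[0.0, 0.0], [2.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]]]
```

For the first element the two steps contribute 1 and 4, and for the second element 1, so the per-element losses are 5 and 1.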

See LossCallbackCallable for additional type signature info.

Return type:

Tensor

batch_loss(x_past, x_future, system, keep_batch=False)[source]

Runs loss_callback (a LossCallbackCallable) on the given batch.

Return type:

Tensor

train_epoch(data, system, optimizer=None)[source]

Trains the learned model for a single epoch. Takes gradient steps in the learned parameters if an optimizer is provided.

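Parameters:
  • data – Training data for the epoch.

  • system – System to be trained.

  • optimizer – Optional optimizer; if None, no gradient steps are taken.

Return type: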

Tensor

Returns:

Scalar average training loss observed during epoch.
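The epoch loop described above can be sketched without PyTorch. In the hypothetical example below, a scalar model y = w * x with a mean-squared-error loss replaces the learned System, a hand-derived gradient replaces autograd, and the optional step function plays the role of the optimizer. Averaging per-batch losses is an assumption about how the epoch average is formed.

```python
from typing import Callable, Dict, Iterable, List, Optional, Tuple

Batch = List[Tuple[float, float]]  # (input, target) pairs


def train_epoch_sketch(data: Iterable[Batch], params: Dict[str, float],
                       step: Optional[Callable[[Dict[str, float], float], None]] = None
                       ) -> float:
    """Hypothetical epoch loop for a scalar model y = w * x with MSE loss."""
    losses = []
    for batch in data:
        w = params["w"]
        n = len(batch)
        loss = sum((w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (w * x - y) * x for x, y in batch) / n  # dloss/dw by hand
        if step is not None:
            step(params, grad)  # update parameters only when an "optimizer" is given
        losses.append(loss)
    return sum(losses) / len(losses)  # average of per-batch losses


def sgd_step(params: Dict[str, float], grad: float, lr: float = 0.1) -> None:
    """Plain gradient-descent update standing in for a torch.optim.Optimizer."""
    params["w"] -= lr * grad
```

Calling the loop without a step function evaluates the loss while leaving the parameters untouched.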

base_and_learned_comparison_summary(statistics, learned_system)[source]

Extracts a SystemSummary that compares the base system to the learned system.

Parameters:
  • statistics (Dict) – Dictionary of training statistics.

  • learned_system (System) – Most updated version of system during training.

Return type:

SystemSummary

Returns:

Summary of comparison between systems.

write_to_wandb(epoch, learned_system, statistics)[source]

Extracts and writes a summary of training progress to Weights & Biases (W&B).

Parameters:
  • epoch (int) – Current epoch.

  • learned_system (System) – System being trained.

  • statistics (Dict) – Summary statistics for learning process.

Return type:

None

per_epoch_evaluation(epoch, learned_system, train_loss, training_duration)[source]

Evaluates and logs training progress at end of an epoch.

Runs evaluation on full slice datasets, as well as a handful of trajectories.

Optionally logs the results to W&B via write_to_wandb().

Parameters:
  • epoch (int) – Current epoch.

  • learned_system (System) – System being trained.

  • train_loss (Tensor) – Scalar training loss of epoch.

  • training_duration (float) – Duration of epoch training in seconds.

Return type:

Tensor

Returns:

Scalar validation set loss.

setup_training()[source]

Sets up initial condition for training process.

Attempts to load initial condition from disk as a TrainingState. Otherwise, a fresh training process is started.

Return type:

Tuple[System, Optimizer, TrainingState]

Returns:

Initial learned system.

PyTorch optimizer.

Current state of the training process.

train(epoch_callback=<function 'default_epoch_callback'>)[source]

Run training process for experiment.

Training terminates via early stopping, whose parameters are set in the config.

Parameters:

epoch_callback (Callable[[int, System, Tensor, Tensor], None]) – Callback function at end of each epoch.

Return type:

Tuple[Tensor, Tensor, System]

Returns:

Final-epoch training loss.

Best-seen validation set loss.

Fully-trained system, with parameters corresponding to best-seen validation loss.
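Early stopping of this kind can be tracked with the same counters stored in TrainingState (best_valid_loss and epochs_since_best). The sketch below replays a precomputed list of validation losses; the patience parameter and the rule of stopping once epochs_since_best exceeds it are hypothetical, since the actual config field names are not given here.

```python
from typing import List, Tuple


def early_stopping_run(valid_losses: List[float], patience: int) -> Tuple[int, float, int]:
    """Hypothetical early-stopping bookkeeping.

    Returns (last epoch run, best validation loss, epoch of best loss).
    """
    best_valid_loss = float("inf")
    best_epoch = 0
    epochs_since_best = 0
    epoch = 0
    for epoch, valid_loss in enumerate(valid_losses, start=1):
        if valid_loss < best_valid_loss:
            best_valid_loss, best_epoch = valid_loss, epoch
            epochs_since_best = 0  # a real loop would snapshot the system here
        else:
            epochs_since_best += 1
        if epochs_since_best > patience:
            break  # no improvement for more than `patience` epochs
    return epoch, best_valid_loss, best_epoch
```

Returning the best-epoch result rather than the last one matches the documented behaviour of keeping the parameters from the best-seen validation loss.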

evaluate_systems_on_sets(systems, sets)[source]

Evaluate given systems on trajectory sets.

Builds a “statistics” dictionary containing a thorough evaluation of each system on each set, including the following:

  • Single step and trajectory prediction losses.

  • Squared norms of velocity and delta-velocity (for normalization).

  • Sample target and prediction trajectories.

  • Auxiliary trajectory comparisons defined in dair_pll.state_space.StateSpace.auxiliary_comparisons().

  • Summary statistics of the above where applicable.

Parameters:
  • systems – Systems to evaluate.

  • sets – Trajectory sets on which each system is evaluated.

Return type:

Dict[str, Union[List, float, ndarray]]

Returns:

Statistics dictionary.

Warning

Currently assumes prediction horizon of 1.
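The nested structure of such an evaluation (every system on every set, followed by summary statistics) can be sketched as follows. Each “system” here is a hypothetical function returning a per-trajectory loss, and key names such as valid_oracle_loss_mean are invented for illustration; they are not the keys the real statistics dictionary uses.

```python
from statistics import mean, median
from typing import Callable, Dict, List, Union


def evaluate_on_sets(systems: Dict[str, Callable[[List[float]], float]],
                     sets: Dict[str, List[List[float]]]) -> Dict[str, Union[List, float]]:
    """Hypothetical evaluation grid: every system on every set."""
    stats: Dict[str, Union[List, float]] = {}
    for set_name, trajectories in sets.items():
        for system_name, loss_fn in systems.items():
            losses = [loss_fn(traj) for traj in trajectories]
            prefix = f"{set_name}_{system_name}_loss"  # invented key scheme
            stats[prefix] = losses                      # raw per-trajectory values
            stats[prefix + "_mean"] = mean(losses)      # summary statistics
            stats[prefix + "_median"] = median(losses)
    return stats


def squared_norm_loss(traj: List[float]) -> float:
    """Toy per-trajectory loss: squared norm of a 1-D trajectory."""
    return sum(x * x for x in traj)


stats = evaluate_on_sets({"oracle": squared_norm_loss}, {"valid": [[1.0, 1.0], [2.0]]})
```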

generate_results(epoch_callback=<function 'default_epoch_callback'>)[source]

Gets the final learned model and the results/statistics of the experiment. Along with the model corresponding to the best validation loss, this returns previously saved results from disk if they already exist, or runs the experiment to generate them if they don’t.

Parameters:

epoch_callback (Callable[[int, System, Tensor, Tensor], None]) – Callback function at end of each epoch.

Return type:

Tuple[System, Dict[str, Union[List, float, ndarray]]]

Returns:

Fully-trained system, with parameters corresponding to best-seen validation loss.

Statistics dictionary.
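The reuse-or-run behaviour of generate_results() is a standard caching pattern. The sketch below illustrates it with JSON files; the function names, the file path, and the storage format are all hypothetical and do not reflect how dair_pll actually stores results.

```python
import json
import os
import tempfile
from typing import Callable, Dict, List, Tuple


def load_or_generate(path: str, run: Callable[[], Dict]) -> Dict:
    """Hypothetical cache: reuse saved results if present, else run and save."""
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    results = run()
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f)
    return results


def demo() -> Tuple[Dict, Dict, int]:
    """Calls load_or_generate twice; the second call should hit the cache."""
    calls: List[int] = []

    def run() -> Dict:
        calls.append(1)  # count how often the expensive step actually runs
        return {"valid_loss": 0.5}

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "results.json")
        first = load_or_generate(path, run)
        second = load_or_generate(path, run)  # served from disk
    return first, second, len(calls)
```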