train

Flexibly parameterized and feature-rich training loop for models.

Allows for warm-up periods as well as early stopping with recovery of the best checkpoint from CPU memory or disk. Use the provided callback examples to monitor training progress or derive your own from the respective base classes. Custom templates for train and test data to seamlessly work with the training loop are also provided.

class Trainer(loss, optimizer=Curry(AdamW), batch_size=64, max_epochs=100, scheduler=Curry(NoSchedule), warmup=0, batch_step=False, patience=None, max_n=None, step_freq=1, clip_grad=1.0, checkpoint=InMemory(), show_progress=True, step_cbs=(), cb_freq=1, epoch_cbs=EpochPrinter(print), train_cbs=TrainPrinter(print))[source]

Bases: ArgRepr

Train and, optionally, evaluate a model with early stopping.

Parameters:
  • loss (Module) – PyTorch module that accepts the output(s) of the model to train as first argument(s) and the target as last argument and that produces the (scalar) loss to minimize, i.e., reduction must be “mean” or “sum”.

  • optimizer (Curry[Optimizer], optional) – A curry of a preconfigured PyTorch Optimizer (or some other custom construct) that returns a fully configured PyTorch Optimizer when called with the model.parameters() of the model to train. Defaults to AdamW with its default parameters.

  • batch_size (int, optional) – Size of the mini-batches to request from the training data. Defaults to 64.

  • max_epochs (int, optional) – Maximum number of epochs to train for. Defaults to 100.

  • scheduler (Curry[LRScheduler], optional) – A curry of a PyTorch learning-rate scheduler (or some other custom construct) that returns a fully configured learning-rate scheduler when called with a PyTorch optimizer. Defaults to never changing the learning rate at all.

  • warmup (int, optional) – At the beginning of training, the learning-rate scheduler will be stepped after every optimizer step for this many steps. Afterward, it will only be stepped at the end of each epoch (unless batch_step is True). Defaults to 0, which results in no warmup and the learning-rate scheduler being stepped for the first time at the end of the first epoch.

  • batch_step (bool, optional) – Whether to step the learning-rate scheduler after each batch or after each epoch once the warmup period is over. Defaults to False.

  • patience (int, optional) – If patience is not None and smaller than max_epochs, early stopping is active. A snapshot of the model’s state is taken after each epoch that improved the loss below its last minimum. If no improvement occurs for patience epochs, model training is stopped (even if max_epochs has not been reached yet) and the model is reset to its best state.

  • max_n (int, optional) – Maximum number of data points to take from the training data to evaluate the train loss after each epoch. Defaults to the number of data points in the test set (if present) or the train set (if not).

  • step_freq (int, optional) – For how many batches to accumulate gradients before taking an optimization step. Defaults to 1, which corresponds to no accumulation.

  • clip_grad (float, optional) – Clip gradients such that their overall norm is capped by the given value. Defaults to 1.0.

  • checkpoint (Checkpoint, optional) – Whenever the train (or test) loss after an epoch drops below its previous minimum, a new snapshot of the model state is saved by calling the save method of the checkpoint instance. Defaults to InMemory.

  • show_progress (bool, optional) – Whether to provide visual feedback to the console while training an epoch in the form of a progress bar. Defaults to True.

  • step_cbs (iterable of StepCallback, optional) – All per-step callbacks will be called every cb_freq batches with the train loss of the last batch, the current learning rate, and the current gradient norm. Defaults to an empty tuple, which does nothing.

  • cb_freq (int, optional) – Number of batches to skip before calling the step callbacks again. Defaults to 1.

  • epoch_cbs (iterable of EpochCallback, optional) – All epoch callbacks will be called after each epoch with the epoch, train loss, test loss, current learning rate, a reference to the model, and a sample of the data. Defaults to a single EpochPrinter.

  • train_cbs (iterable of TrainCallback, optional) – All train callbacks will be called after training has finished with the last epoch, the epoch with the best loss, the best loss itself, whether max_epochs was exhausted, and the training history in the form of a dictionary of lists with train losses, test losses, and learning rates. Defaults to a single TrainPrinter.

Important

optimizer

Because the (partial) optimizer will simply be completed with model.parameters(), parameter groups are not supported.

scheduler

During warmup, the scheduler is called after each optimizer step and, afterward, at the end of each epoch. The scheduler thus needs to be aware of the warmup settings and act accordingly.

step_freq

If the reduction of the loss is “mean”, the loss of each batch is divided by this number before performing the backward pass. Strictly speaking, each batch should thus have the exact same number of data points in this case. Furthermore, complications might arise when the model to train contains elements like BatchNorm.
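
As an illustration, a Trainer might be wired up as sketched below. The import path mypackage.train is a placeholder, and functools.partial stands in for Curry; per the parameter descriptions above, any construct that returns a configured optimizer when called with model.parameters() (or a configured scheduler when called with an optimizer) will do:

    from functools import partial

    import torch.nn as nn
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import LambdaLR

    # Hypothetical import path -- adjust to wherever Trainer and LinearInverse
    # live in your installation.
    from mypackage.train import Trainer, LinearInverse

    warmup_steps = 500

    trainer = Trainer(
        loss=nn.MSELoss(reduction='mean'),      # reduction must be "mean" or "sum"
        optimizer=partial(AdamW, lr=1e-3),      # completed later with model.parameters()
        scheduler=partial(LambdaLR, lr_lambda=LinearInverse(warmup=warmup_steps)),
        batch_size=128,
        max_epochs=50,
        warmup=warmup_steps,                    # scheduler stepped per optimizer step at first
        patience=5,                             # stop early after 5 epochs without improvement
        step_freq=2,                            # accumulate gradients over 2 batches
        clip_grad=1.0,
    )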

resume(model, train, test=None)[source]

Resume model training from the best epoch checkpointed so far.

Parameters:
  • model (Module) – PyTorch Module to train.

  • train (TrainDataBase) – Training data.

  • test (TestDataBase, optional) – Hold-out data to compute the test loss. If early stopping is active, it will track the test loss if test is given and the train loss if it is not.

Returns:

The trained PyTorch model.

Return type:

Module

Note

This is safe to use even if you have never trained your model before and are starting from scratch.

Important

The model to train must always return a tuple of tensors, not just a single tensor. If it produces only one tensor, return it as a 1-tuple of tensors.

property scale

Scaling factor for the accumulated loss before the backward pass.

train(model, train, test=None)[source]

Train a fresh model from scratch.

Parameters:
  • model (Resettable) – PyTorch Module to train. Must have a reset_parameters() method that can be called without any arguments to re-initialize all trainable model parameters and buffers.

  • train (TrainDataBase) – Training data.

  • test (TestDataBase, optional) – Hold-out data to compute the test loss. If early stopping is active, it will track the test loss if test is given and the train loss if it is not.

Returns:

The trained PyTorch model.

Return type:

Module

Raises:

TrainError – If the model does not have a reset_parameters() method.

Important

The model to train must always return a tuple of tensors, not just a single tensor. If it produces only one tensor, return it as a 1-tuple of tensors.

Warning

Model training is (re-)started from scratch every time this method is called. The training history is erased, all internal model parameters are reset to a pristine initial state, and previously saved checkpoints are irrevocably deleted!
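
Both train() and resume() expect a model that returns a tuple of tensors. A minimal sketch, assuming trainer is the instance configured above and that train_data and test_data are your own TrainDataBase and TestDataBase instances:

    from torch import nn, Tensor


    class TinyRegressor(nn.Module):
        """Toy model satisfying the Trainer's contract."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(8, 1)

        def reset_parameters(self) -> None:
            # Required by train() to re-initialize all trainable parameters.
            self.linear.reset_parameters()

        def forward(self, x: Tensor) -> tuple[Tensor]:
            # Output must be a tuple of tensors, even if there is only one.
            return self.linear(x),


    model = trainer.train(TinyRegressor(), train_data, test_data)  # fresh start
    model = trainer.resume(model, train_data, test_data)           # continue from best checkpoint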

class InMemory[source]

Bases: Checkpoint

Checkpoint in CPU memory regardless of the device training runs on.

property counter

How many checkpoints were saved since the last reset.

load(model, optimizer=None, scheduler=None)

User-facing method for loading a checkpoint during training.

The state retrieved from the checkpoint is merged into the state of the objects to update, such that loading a checkpoint before one has been saved results in unchanged objects.

Parameters:
  • model (Module) – The model to load the state for in-place.

  • optimizer (Optimizer, optional) – The optimizer to load the state for in-place. Defaults to None.

  • scheduler (LRScheduler, optional) – The scheduler to load the state for in-place. Defaults to None.

Returns:

  • epoch (int) – The epoch of the checkpoint.

  • loss (float) – The loss from the checkpoint.

reset_parameters()

Hard-resets the checkpoint into a pristine, initial state.

save(epoch, loss, model, optimizer=None, scheduler=None)

User-facing method for saving a checkpoint during training.

Parameters:
  • epoch (int) – The current epoch.

  • loss (float) – The (train or test) loss at the current epoch.

  • model (Module) – The model to checkpoint the state of.

  • optimizer (Optimizer, optional) – The optimizer to checkpoint the state of. Defaults to None.

  • scheduler (LRScheduler, optional) – The scheduler to checkpoint the state of. Defaults to None.

class OnDisk(path, create=False, not_found=NotFound.WARN)[source]

Bases: Checkpoint

Checkpoint the current state of training on disk.

Parameters:
  • path (str) – Full path to and name of the file used to persist state.

  • create (bool, optional) – Whether to create the directory where the checkpoint should be saved if it does not exist. Defaults to False.

  • not_found (str, optional) – What to do if a checkpoint is loaded from the specified path but none is there yet. One of “ignore”, “warn”, or “raise”. Defaults to “warn”. Use the NotFound enum to avoid typos!

See also

NotFound
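
For example, persisting checkpoints to disk might look like the following sketch (the import path is, again, a placeholder):

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import OnDisk, NotFound

    checkpoint = OnDisk(
        'checkpoints/run_42.pt',   # file used to persist model/optimizer/scheduler state
        create=True,               # create the parent directory if it does not exist
        not_found=NotFound.WARN,   # only warn when loading before anything was saved
    )
    # Hand it to the training loop via, e.g., Trainer(..., checkpoint=checkpoint)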

property counter

How many checkpoints were saved since the last reset.

load(model, optimizer=None, scheduler=None)

User-facing method for loading a checkpoint during training.

The state retrieved from the checkpoint is merged into the state of the objects to update, such that loading a checkpoint before one has been saved results in unchanged objects.

Parameters:
  • model (Module) – The model to load the state for in-place.

  • optimizer (Optimizer, optional) – The optimizer to load the state for in-place. Defaults to None.

  • scheduler (LRScheduler, optional) – The scheduler to load the state for in-place. Defaults to None.

Returns:

  • epoch (int) – The epoch of the checkpoint.

  • loss (float) – The loss from the checkpoint.

reset_parameters()

Hard-resets the checkpoint into a pristine, initial state.

save(epoch, loss, model, optimizer=None, scheduler=None)

User-facing method for saving a checkpoint during training.

Parameters:
  • epoch (int) – The current epoch.

  • loss (float) – The (train or test) loss at the current epoch.

  • model (Module) – The model to checkpoint the state of.

  • optimizer (Optimizer, optional) – The optimizer to checkpoint the state of. Defaults to None.

  • scheduler (LRScheduler, optional) – The scheduler to checkpoint the state of. Defaults to None.

class StepPrinter(printer=<built-in function print>, sep=', ')[source]

Bases: ArgRepr, StepCallback

Step callback assembling a one-liner on per-batch training progress.

Parameters:
  • printer (callable, optional) – Will be called with the assembled message. Defaults to the Python built-in print function, but could also be a logging command.

  • sep (str, optional) – The items concatenated into a one-line message will be separated by this string. Defaults to “, ”.

__call__(train_loss, learning_rate, gradient_norm)[source]

Assemble one-liner with loss, lr, and norm, and call the printer.

close()[source]

Does nothing because there is nothing to close.
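
For instance, the assembled one-liner can just as well be routed to the logging framework (import path assumed):

    import logging

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import StepPrinter

    logging.basicConfig(level=logging.INFO)

    # Route the per-batch message to a logger instead of stdout and separate
    # its items with a pipe. Hand the instance to the Trainer via step_cbs.
    step_printer = StepPrinter(logging.getLogger('train').info, sep=' | ')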

class EpochPrinter(printer=<built-in function print>)[source]

Bases: ArgRepr, EpochCallback

Epoch callback assembling an informative message on training progress.

Parameters:

printer (callable, optional) – Will be called with the assembled message. Defaults to the Python built-in print function, but could also be a logging command.

__call__(epoch, train_loss, test_loss, learning_rate, model, data)[source]

Assemble training-progress message and call the printer with it.

close()[source]

Does nothing because there is nothing to close.

class TrainPrinter(printer=<built-in function print>)[source]

Bases: ArgRepr, TrainCallback

Train callback assembling an informative message when training ends.

Parameters:

printer (callable, optional) – Will be called with the assembled message. Defaults to the Python built-in print function, but could also be a logging command.

__call__(epoch, best_epoch, best_loss, max_epochs_reached, history)[source]

Assemble a summary of model training and call the printer with it.

Parameters:
  • epoch (int) – The last epoch in the training loop.

  • best_epoch (int) – The epoch with the lowest loss encountered.

  • best_loss (float) – The lowest loss encountered.

  • max_epochs_reached (bool) – Whether the maximum number of epochs was exhausted or not.

  • history (History) – Dictionary with lists of train losses, test losses, and learning rates.

class History[source]

Bases: TypedDict

Summary of metrics passed to the callback when training is finished.

lr

List of learning rates.

test_loss

List of losses evaluated on test data.

train_loss

List of losses evaluated on train data.

Schedulers

class LinearInverse(warmup=0, power=0.5)[source]

Bases: ArgRepr

Scale up learning rate during warmup before decaying with inverse power.

Instances of this class are not learning-rate schedulers by themselves! They are intended to be passed as lr_lambda argument to PyTorch’s LambdaLR learning-rate scheduler.

Parameters:
  • warmup (int, optional) – Number of steps during which the learning rate will be linearly scaled up to the one specified in the optimizer. Defaults to 0, resulting in the learning rate already being at its maximum value for the first step and only decaying thereafter.

  • power (float, optional) – After warmup steps, the learning rate is scaled down with the inverse step number taken to the power of this number. Values are clipped to lie in the interval [0.5, 1.0]. Defaults to 0.5, the slowest decay.

__call__(step)[source]

Learning rate scaling factor depending on the step.

Parameters:

step (int) – Step to return the learning-rate scaling factor for.

Returns:

The learning-rate scaling factor

Return type:

float
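
A sketch of wiring such a factor into PyTorch's LambdaLR and handing it to the Trainer; functools.partial stands in for Curry and the import path is a placeholder:

    from functools import partial

    from torch.optim.lr_scheduler import LambdaLR

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import LinearInverse

    warmup_steps = 1000

    # LinearInverse only yields the multiplicative factor; LambdaLR applies it
    # to the optimizer's base learning rate on every scheduler step.
    scheduler = partial(LambdaLR, lr_lambda=LinearInverse(warmup=warmup_steps, power=0.5))

    # Hand it over via, e.g., Trainer(..., scheduler=scheduler, warmup=warmup_steps)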

class LinearExponential(warmup=0, gamma=0.95)[source]

Bases: ArgRepr

Scale up learning rate during warmup before decaying exponentially.

Instances of this class are not learning-rate schedulers by themselves! They are intended to be passed as lr_lambda argument to PyTorch’s LambdaLR learning-rate scheduler.

Parameters:
  • warmup (int, optional) – Number of steps during which the learning rate will be linearly scaled up to the one specified in the optimizer. Defaults to 0, resulting in the learning rate already being at its maximum value for the first step and only decaying thereafter.

  • gamma (float, optional) – After warmup steps, the learning rate is scaled down with this number to the power of the step number. Therefore, it must lie in the interval (0.0, 1.0). Defaults to 0.95.

__call__(step)[source]

Learning rate scaling factor depending on the step.

Parameters:

step (int) – Step to return the learning-rate scaling factor for.

Returns:

The learning-rate scaling factor

Return type:

float

class LinearCosine(warmup=0, cooldown=100)[source]

Bases: ArgRepr

Scale up learning rate during warmup before decaying with cosine.

Instances of this class are not learning-rate schedulers by themselves! They are intended to be passed as lr_lambda argument to PyTorch’s LambdaLR learning-rate scheduler.

Parameters:
  • warmup (int, optional) – Number of steps during which the learning rate will be linearly scaled up to the one specified in the optimizer. Defaults to 0, resulting in the learning rate already being at its maximum value for the first step and only decaying thereafter.

  • cooldown (int, optional) – After warmup steps, the learning rate is scaled down with a cosine function for this many steps. Defaults to 100, but must be at least 1.

Notes

The learning rate will never actually reach 0. Rather, it stays at the last value right before reaching 0. How small this value is depends on the choice of cooldown. For cooldown = 1, it will stay at 1.0, for cooldown = 2, it will stay at 0.5, and so on.

__call__(step)[source]

Learning rate scaling factor depending on the step.

Parameters:

step (int) – Step to return the learning-rate scaling factor for.

Returns:

The learning-rate scaling factor

Return type:

float
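
To get a feel for the resulting schedule, the factors can be inspected directly (import path assumed):

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import LinearCosine

    factor = LinearCosine(warmup=10, cooldown=100)

    # Linear ramp-up over the first 10 steps, cosine decay over the next 100,
    # then the factor stays at its last value above zero.
    for step in (0, 5, 10, 60, 110, 200):
        print(step, factor(step))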

Base classes

class TestDataBase[source]

Bases: ABC

abstract property n

The total number of data points.

abstractmethod sample(batch_size, max_n=None)[source]

Iterator over batches from a reproducible sample of your data.

Needed to consistently evaluate (and report) the train and/or test error after every epoch of training. Every time sample is called with the same arguments, the exact same output should be returned so that train and test errors can be used to track training convergence.

Parameters:
  • batch_size (int) – Size of the mini-batches in the reproducible sample of your data. Every time sample is called with the same size, the exact same output should be returned so that train and test errors can be used to track training convergence.

  • max_n (int, optional) – Total number of data points to return mini-batches for. Defaults to None, resulting in all available data points being used. To save time, however, you might not want to use all available training data points just to evaluate your model metrics.

Returns:

  • tuple – A tuple of input tensors to call your model with.

  • Tensor – The matching target values. Must have the same dimensions and sizes as the output of your model.

Important

Even if your model only takes a single tensor as input, this method must return a tuple to cover the general case of your model taking more than one tensor as input. If this is not the case, simply return a 1-tuple of tensors.
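
A minimal sketch of a tensor-backed implementation, assuming the base class is importable from a placeholder path and that n and sample are all that must be provided:

    from torch import Tensor

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import TestDataBase


    class TensorTestData(TestDataBase):
        """Hold-out data kept in two pre-loaded tensors."""

        def __init__(self, features: Tensor, targets: Tensor) -> None:
            self.features = features
            self.targets = targets

        @property
        def n(self) -> int:
            return self.features.shape[0]

        def sample(self, batch_size, max_n=None):
            # Iterate in a fixed order so repeated calls yield identical batches.
            stop = self.n if max_n is None else min(max_n, self.n)
            for start in range(0, stop, batch_size):
                end = min(start + batch_size, stop)
                # Features go into a 1-tuple because the model takes tuple input(s).
                yield (self.features[start:end],), self.targets[start:end]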

class TrainDataBase[source]

Bases: TestDataBase

abstractmethod __call__(batch_size, step_freq=1, epoch=0)[source]

Return an iterator over the mini-batches your model is trained on.

Parameters:
  • batch_size (int) – The (maximum) number of data points in one batch.

  • step_freq (int, optional) – In case this number is > 1, the optimizer will accumulate gradients for that many batches before taking a step. All batches should be of the same size in this case, and there shouldn’t be any “left-over” batches at the end of each epoch. Defaults to 1.

  • epoch (int, optional) – In rare cases, it might be useful to know the current epoch when deciding which batches to return. Counting starts at 1 in the first epoch; defaults to 0.

Returns:

  • n_batches (int) – Total number of batches the returned iterator will provide. If unknown for some reason, None can be returned instead.

  • batches (Iterator) – One element yielded by the iterator is a 2-tuple. The first element is again a tuple, containing the input tensor(s) to call your model with. The second element of the tuple is a single tensor with the matching target values. It must have the same dimensions and sizes as the output of your model.

Important

Even if your model only takes a single tensor as input, the first element of each tuple yielded by the returned iterator must always be a tuple of tensors to cover the general case of your model taking more than one tensor as input. If this is not the case, simply make that a 1-tuple of tensors!

adjust_batches_for(batch_size, step_freq=1, n=None)[source]

Number of batches reduced to be suitably integer-divisible.

This is a helper method for users to implement the __call__ method in the case of step_freq > 1. The returned number of batches is guaranteed to be integer-divisible by step_freq so that no batches are “left over” at the end of the epoch.

Parameters:
  • batch_size (int) – The desired number of data points in one batch.

  • step_freq (int, optional) – In case this number is > 1, the optimizer will accumulate gradients for that many batches before taking a step. All batches should be of the same size in this case and there shouldn’t be any “left-over” batches at the end of each epoch. Defaults to 1.

  • n (int, optional) – In rare cases, it might be useful to pass in the number of data points to adjust rather than taking the number of data points returned by self.n, which is the default.

Returns:

Reduced number of batches that is guaranteed to be integer divisible by step_freq.

Return type:

int

adjust_n_for(batch_size, step_freq=1, n=None)[source]

Number of data points reduced to be suitably integer-divisible.

This is a helper method for users to implement the __call__ method in the case of step_freq > 1. Taking only the returned number of data points guarantees that all batches have the same size and that there will be no “left-over” batches at the end of the epoch.

Parameters:
  • batch_size (int) – The desired number of data points in one batch.

  • step_freq (int, optional) – In case this number is > 1, the optimizer will accumulate gradients for that many batches before taking a step. All batches should be of the same size in this case and there shouldn’t be any “left-over” batches at the end of each epoch. Defaults to 1.

  • n (int, optional) – In rare cases, it might be useful to pass in the number of data points to adjust rather than taking the number of data points returned by self.n, which is the default.

Returns:

Reduced number of data points that is guaranteed to be integer divisible by the product of batch_size and step_freq.

Return type:

int

abstract property n

The total number of data points.

abstractmethod sample(batch_size, max_n=None)

Iterator over batches from a reproducible sample of your data.

Needed to consistently evaluate (and report) the train and/or test error after every epoch of training. Every time sample is called with the same arguments, the exact same output should be returned so that train and test errors can be used to track training convergence.

Parameters:
  • batch_size (int) – Size of the mini-batches in the reproducible sample of your data. Every time sample is called with the same size, the exact same output should be returned so that train and test errors can be used to track training convergence.

  • max_n (int, optional) – Total number of data points to return mini-batches for. Defaults to None, resulting in all available data points being used. To save time, however, you might not want to use all available training data points just to evaluate your model metrics.

Returns:

  • tuple – A tuple of input tensors to call your model with.

  • Tensor – The matching target values. Must have the same dimensions and sizes as the output of your model.

Important

Even if your model only takes a single tensor as input, this method must return a tuple to cover the general case of your model taking more than one tensor as input. If this is not the case, simply return a 1-tuple of tensors.
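
Putting the pieces together, a minimal tensor-backed implementation might look like the following sketch (import path assumed; the inherited helpers take care of gradient accumulation):

    import torch
    from torch import Tensor

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import TrainDataBase


    class TensorTrainData(TrainDataBase):
        """Training data kept in two pre-loaded tensors, reshuffled every epoch."""

        def __init__(self, features: Tensor, targets: Tensor) -> None:
            self.features = features
            self.targets = targets

        @property
        def n(self) -> int:
            return self.features.shape[0]

        def sample(self, batch_size, max_n=None):
            # Deterministic, un-shuffled batches for reproducible loss evaluation.
            stop = self.n if max_n is None else min(max_n, self.n)
            for start in range(0, stop, batch_size):
                end = min(start + batch_size, stop)
                yield (self.features[start:end],), self.targets[start:end]

        def __call__(self, batch_size, step_freq=1, epoch=0):
            # Keep only as many points as divide evenly into batch_size * step_freq
            # so that gradient accumulation never sees a "left-over" batch.
            n = self.adjust_n_for(batch_size, step_freq)
            n_batches = self.adjust_batches_for(batch_size, step_freq)
            permutation = torch.randperm(self.n)[:n]

            def batches():
                for start in range(0, n, batch_size):
                    index = permutation[start:start + batch_size]
                    yield (self.features[index],), self.targets[index]

            return n_batches, batches()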

class StepCallback[source]

Bases: ABC

Base class to inherit from when implementing custom step callbacks.

abstractmethod __call__(train_loss, learning_rate, gradient_norm)[source]

Called after processing a batch to print, log, or save information.

Parameters:
  • train_loss (float) – The training loss on the current batch.

  • learning_rate (float) – The learning rate used in the current optimization step.

  • gradient_norm (float) – Norm of the current gradients.

abstractmethod close()[source]

Called at the end of training to be compatible with TensorBoard.
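
A sketch of a custom step callback that appends per-batch metrics to a CSV file (import path assumed):

    import csv

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import StepCallback


    class CsvStepLogger(StepCallback):
        """Append per-batch metrics to a CSV file."""

        def __init__(self, path: str) -> None:
            self.file = open(path, 'w', newline='')
            self.writer = csv.writer(self.file)
            self.writer.writerow(['train_loss', 'learning_rate', 'gradient_norm'])

        def __call__(self, train_loss, learning_rate, gradient_norm):
            self.writer.writerow([train_loss, learning_rate, gradient_norm])

        def close(self):
            # Called once at the end of training.
            self.file.close()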

class EpochCallback[source]

Bases: ABC

Base class to inherit from when implementing custom epoch callbacks.

abstractmethod __call__(epoch, train_loss, test_loss, learning_rate, model, data)[source]

Called after each epoch to print, log, or otherwise analyse metrics.

Parameters:
  • epoch (int) – The current epoch in the training loop.

  • train_loss (float) – The loss computed on a sample of the train data after the current epoch.

  • test_loss (float) – The loss computed on a sample of the test data after the current epoch. Always nan if no test data is used.

  • learning_rate (float) – Learning rate of the optimizer in the current epoch.

  • model (Module) – A reference to the model being trained.

  • data (Batches) – An iterator over 2-tuples of a feature-tensor tuple and a target tensor to call the model with. Sampled from the test data if present and from the train data otherwise.

abstractmethod close()[source]

Called at the end of training to be compatible with TensorBoard.
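
Because close() exists for TensorBoard compatibility, a typical custom epoch callback wraps a SummaryWriter. A minimal sketch, with the import path of the base class assumed:

    from torch.utils.tensorboard import SummaryWriter

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import EpochCallback


    class TensorBoardEpochLogger(EpochCallback):
        """Write per-epoch metrics to TensorBoard."""

        def __init__(self, log_dir: str) -> None:
            self.writer = SummaryWriter(log_dir)

        def __call__(self, epoch, train_loss, test_loss, learning_rate, model, data):
            self.writer.add_scalar('loss/train', train_loss, epoch)
            self.writer.add_scalar('loss/test', test_loss, epoch)
            self.writer.add_scalar('learning_rate', learning_rate, epoch)

        def close(self):
            # Flush and release the writer at the end of training.
            self.writer.close()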

class TrainCallback[source]

Bases: ABC

Base class to inherit from when implementing custom train callbacks.

abstractmethod __call__(epoch, best_epoch, best_loss, max_epochs_reached, history)[source]

Called after training has finished to print, log, or save a summary.

Parameters:
  • epoch (int) – The last epoch in the training loop.

  • best_epoch (int) – The epoch with the lowest loss encountered.

  • best_loss (float) – The lowest loss encountered.

  • max_epochs_reached (bool) – Whether the maximum number of epochs was exhausted or not.

  • history (History) – Dictionary with lists of train losses, test losses, and learning rates.
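
A sketch of a custom train callback that persists the summary to JSON (import path assumed):

    import json

    # Hypothetical import path -- adjust to your installation.
    from mypackage.train import TrainCallback


    class JsonTrainSummary(TrainCallback):
        """Dump a summary of the finished training run to a JSON file."""

        def __init__(self, path: str) -> None:
            self.path = path

        def __call__(self, epoch, best_epoch, best_loss, max_epochs_reached, history):
            summary = {
                'last_epoch': epoch,
                'best_epoch': best_epoch,
                'best_loss': best_loss,
                'max_epochs_reached': max_epochs_reached,
                'history': history,  # lists of train losses, test losses, and learning rates
            }
            with open(self.path, 'w') as file:
                json.dump(summary, file, indent=2)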