train
Flexibly parameterized and feature-rich training loop for models.
Allows for warm-up periods as well as early stopping with recovery of the best checkpoint from CPU memory or disk. Use the provided callback examples to monitor training progress or derive your own from the respective base classes. Custom templates for train and test data to seamlessly work with the training loop are also provided.
- class Trainer(loss, optimizer=Curry(AdamW), batch_size=64, max_epochs=100, scheduler=Curry(NoSchedule), warmup=0, batch_step=False, patience=None, max_n=None, step_freq=1, clip_grad=1.0, checkpoint=InMemory(), show_progress=True, step_cbs=(), cb_freq=1, epoch_cbs=EpochPrinter(print), train_cbs=TrainPrinter(print))[source]
Bases:
ArgRepr
Train and, optionally, evaluate a model with early stopping.
- Parameters:
loss (Module) – PyTorch module that accepts the output(s) of the model to train as first argument(s) and the target as last argument and that produces the (scalar) loss to minimize, i.e., its reduction must be “mean” or “sum”.
optimizer (Curry[Optimizer], optional) – A curry of a preconfigured PyTorch Optimizer (or some other custom construct) that returns a fully configured PyTorch Optimizer when called with the model.parameters() of the model to train. Defaults to AdamW with its default parameters.
batch_size (int, optional) – Size of the mini-batches to request from the training data. Defaults to 64.
max_epochs (int, optional) – Maximum number of epochs to train for. Defaults to 100.
scheduler (Curry[LRScheduler], optional) – A curry of a PyTorch learning-rate scheduler (or some other custom construct) that returns a fully configured learning-rate scheduler when called with a PyTorch optimizer. Defaults to never changing the learning rate at all.
warmup (int, optional) – At the beginning of training, the learning-rate scheduler will be stepped after every optimizer step for this many times. Afterward, it will only be stepped at the end of each epoch. Defaults to 0, which results in no warmup and the learning-rate scheduler being stepped for the first time at the end of the first epoch.
batch_step (bool, optional) – Whether to step the learning-rate scheduler after each batch or after each epoch once the warmup period is over. Defaults to False.
patience (int, optional) – If patience is not None and smaller than max_epochs, early stopping is active. A snapshot of the model’s state is taken after each epoch that improved the loss below its last minimum. If no improvement occurs for patience epochs, model training is stopped (even if max_epochs has not been reached yet) and the model is reset to its best state.
max_n (int, optional) – Maximum number of data points to take from the training data to evaluate the train loss after each epoch. Defaults to the number of data points in the test set (if present) or the train set (if not).
step_freq (int, optional) – For how many batches to accumulate gradients before taking an optimization step. Defaults to 1, which corresponds to no accumulation.
clip_grad (float, optional) – Clip gradients such that their overall norm is capped by the given value. Defaults to 1.0.
checkpoint (Checkpoint, optional) – Whenever the train (or test) loss after an epoch is smaller than it was after the last one, a new snapshot of the model state is saved by calling the save method of the checkpoint instance. Defaults to InMemory.
show_progress (bool, optional) – Whether to provide visual feedback to the console in the form of a progress bar while training an epoch. Defaults to True.
step_cbs (iterable of StepCallback, optional) – All per-step callbacks will be called every cb_freq batches with the train loss of the last batch and the current learning rate. Defaults to an empty tuple, which does nothing.
cb_freq (int, optional) – Number of batches to skip before calling the step callbacks again. Defaults to 1.
epoch_cbs (iterable of EpochCallback, optional) – All epoch callbacks will be called after each epoch with the epoch, train loss, test loss, and current learning rate. Defaults to a single EpochPrinter.
train_cbs (iterable of TrainCallback, optional) – All train callbacks will be called after training has finished with the last epoch, the epoch with the best loss, the best loss itself, whether max_epochs was exhausted, and the training history in the form of a dictionary of lists with train loss, test loss, and learning rate. Defaults to a single TrainPrinter.
Important
- optimizer
Because the (partial) optimizer will simply be completed with model.parameters(), parameter groups are not supported.
- scheduler
During warmup, the scheduler is called after each optimizer step and, afterward, at the end of each epoch. The scheduler thus needs to be aware of the warmup settings and act accordingly.
- step_freq
If the reduction of the loss is “mean”, the loss of each batch is divided by this number before performing the backward pass. Strictly speaking, each batch should thus have the exact same number of data points in this case. Furthermore, complications might arise when the model to train contains elements like BatchNorm.
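As a sketch of how these pieces fit together (the import locations and the ability of Curry to forward keyword arguments are assumptions; adjust them to your installation), a Trainer might be configured like this:

    import torch.nn as ptn
    from torch.optim import AdamW

    # Loss with reduction "mean" (the default), as required above.
    loss = ptn.MSELoss()

    trainer = Trainer(
        loss,
        optimizer=Curry(AdamW, lr=1e-3),  # assumes Curry forwards keyword arguments
        batch_size=32,
        max_epochs=50,
        patience=5,       # early stopping after 5 epochs without improvement
        clip_grad=1.0,    # cap the overall gradient norm at 1.0
    )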
- resume(model, train, test=None)[source]
Resume model training from the best epoch checkpointed so far.
- Parameters:
model (Module) – PyTorch Module to train.
train (TrainDataBase) – Training data.
test (TestDataBase, optional) – Hold-out data to compute the test loss. If early stopping is configured, it will track the test loss if test is given and the train loss if it is not.
- Returns:
The trained PyTorch model.
- Return type:
Module
Note
This is safe to use even if you have never trained your model before and are starting from scratch.
Important
The model to train must always return a tuple of tensors, not just a single tensor. If it produces only one tensor, return it as a 1-tuple of tensors.
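A minimal usage sketch, assuming the trainer from above (MyModel, TrainData, and TestData are hypothetical stand-ins for your own Module and data implementations):

    model = MyModel()            # must return a tuple of tensors
    train_data = TrainData(...)  # your TrainDataBase implementation
    test_data = TestData(...)    # your TestDataBase implementation

    # Safe to call on a fresh model; otherwise picks up from the best checkpoint.
    model = trainer.resume(model, train_data, test_data)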
- property scale
Scaling factor for the accumulated loss before the backward pass.
- train(model, train, test=None)[source]
Train a fresh model from scratch, starting from a clean slate.
- Parameters:
model (Resettable) – PyTorch Module to train. Must have a reset_parameters() method that can be called without any arguments to re-initialize all trainable model parameters and buffers.
train (TrainDataBase) – Training data.
test (TestDataBase, optional) – Hold-out data to compute the test loss. If early stopping is configured, it will track the test loss if test is given and the train loss if it is not.
- Returns:
The trained PyTorch model.
- Return type:
Module
- Raises:
TrainError – If the model does not have a reset_parameters() method.
Important
The model to train must always return a tuple of tensors, not just a single tensor. If it produces only one tensor, return it as a 1-tuple of tensors.
Warning
Model training is (re-)started from scratch every time this method is called. The training history is erased, all internal model parameters are reset to a pristine initial state, and previously saved checkpoints are irrevocably deleted!
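A sketch of a model that satisfies both requirements (a reset_parameters() method and a tuple return value), assuming the trainer and data objects from the sketches above:

    import torch
    import torch.nn as ptn

    class TinyRegressor(ptn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.linear = ptn.Linear(8, 1)

        def reset_parameters(self) -> None:
            # Called by Trainer.train() to re-initialize all trainable parameters.
            self.linear.reset_parameters()

        def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
            # Return a 1-tuple of tensors, as required by the training loop.
            return self.linear(x),

    model = trainer.train(TinyRegressor(), train_data, test_data)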
- class InMemory[source]
Bases:
Checkpoint
Checkpoint in CPU memory regardless of the device training runs on.
- property counter
How many checkpoints were saved since the last reset.
- load(model, optimizer=None, scheduler=None)
User-facing method for loading a checkpoint during training.
The state retrieved from the checkpoint is merged into the state of the objects to update, such that loading a checkpoint before one has been saved results in unchanged objects.
- Parameters:
model (Module) – The model to load the state for in-place.
optimizer (Optimizer, optional) – The optimizer to load the state for in-place. Defaults to None.
scheduler (LRScheduler, optional) – The scheduler to load the state for in-place. Defaults to None.
- Returns:
epoch (int) – The epoch of the checkpoint.
loss (float) – The loss from the checkpoint.
- reset_parameters()
Hard-resets the checkpoint into a pristine, initial state.
- save(epoch, loss, model, optimizer=None, scheduler=None)
User-facing method for saving a checkpoint during training.
- Parameters:
epoch (int) – The current epoch.
loss (float) – The (train or test) loss at the current epoch.
model (Module) – The model to checkpoint the state of.
optimizer (Optimizer, optional) – The optimizer to checkpoint the state of. Defaults to None.
scheduler (LRScheduler, optional) – The scheduler to checkpoint the state of. Defaults to None.
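Although the Trainer drives these methods internally, a manual save/load roundtrip on some Module instance might look like this (the epoch and loss values are made up):

    checkpoint = InMemory()
    checkpoint.save(3, 0.42, model)

    # ... training continues, the loss gets worse, roll back to the best state ...
    best_epoch, best_loss = checkpoint.load(model)
    print(checkpoint.counter)  # number of checkpoints saved since the last reset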
- class OnDisk(path, create=False, not_found=NotFound.WARN)[source]
Bases:
Checkpoint
Checkpoint the current state of training on disk.
- Parameters:
path (str) – Full path to and name of the file used to persist state.
create (bool, optional) – Whether to create the directory where the checkpoint should be saved if it does not exist. Defaults to False.
not_found (str, optional) – What to do if a checkpoint is loaded from the specified path but none is there yet. One of “ignore”, “warn”, or “raise”. Defaults to “warn”. Use the NotFound enum to avoid typos!
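A sketch of checkpointing to disk instead of CPU memory (the path is, of course, only an example):

    checkpoint = OnDisk(
        'checkpoints/best.pt',    # file used to persist the state
        create=True,              # create the parent directory if needed
        not_found=NotFound.WARN,  # only warn if no checkpoint exists there yet
    )

    trainer = Trainer(loss, checkpoint=checkpoint, patience=5)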
- property counter
How many checkpoints were saved since the last reset.
- load(model, optimizer=None, scheduler=None)
User-facing method for loading a checkpoint during training.
The state retrieved from the checkpoint is merged into the state of the objects to update, such that loading a checkpoint before one has been saved results in unchanged objects.
- Parameters:
model (Module) – The model to load the state for in-place.
optimizer (Optimizer, optional) – The optimizer to load the state for in-place. Defaults to None.
scheduler (LRScheduler, optional) – The scheduler to load the state for in-place. Defaults to None.
- Returns:
epoch (int) – The epoch of the checkpoint.
loss (float) – The loss from the checkpoint.
- reset_parameters()
Hard-resets the checkpoint into a pristine, initial state.
- save(epoch, loss, model, optimizer=None, scheduler=None)
User-facing method for saving a checkpoint during training.
- Parameters:
epoch (int) – The current epoch.
loss (float) – The (train or test) loss at the current epoch.
model (Module) – The model to checkpoint the state of.
optimizer (Optimizer, optional) – The optimizer to checkpoint the state of. Defaults to None.
scheduler (LRScheduler, optional) – The scheduler to checkpoint the state of. Defaults to None.
- class StepPrinter(printer=<built-in function print>, sep=', ')[source]
Bases:
ArgRepr, StepCallback
Step callback assembling a one-liner on per-batch training progress.
- Parameters:
printer (callable, optional) – Will be called with the assembled message. Defaults to the Python builtin print function, but could also be a logging command.
sep (str, optional) – The items concatenated into a one-line message will be separated by this string. Defaults to “, ”.
- class EpochPrinter(printer=<built-in function print>)[source]
Bases:
ArgRepr, EpochCallback
Epoch callback assembling an informative message on training progress.
- Parameters:
printer (callable, optional) – Will be called with the assembled message. Defaults to the Python builtin print function, but could also be a logging command.
- class TrainPrinter(printer=<built-in function print>)[source]
Bases:
ArgRepr, TrainCallback
Train callback assembling an informative message when training ends.
- Parameters:
printer (callable, optional) – Will be called with the assembled message. Defaults to the Python builtin print function, but could also be a logging command.
- __call__(epoch, best_epoch, best_loss, max_epochs_reached, history)[source]
Assemble a summary of model training and call the printer with it.
- Parameters:
epoch (int) – The last epoch in the training loop.
best_epoch (int) – The epoch with the lowest loss encountered.
best_loss (float) – The lowest loss encountered.
max_epochs_reached (bool) – Whether the maximum number of epochs was exhausted or not.
history (History) – Dictionary with lists of train losses, test losses, and learning rates.
- class History[source]
Bases:
TypedDict
Summary of metrics passed to the callback when training is finished.
- lr
List of learning rates.
- test_loss
List of losses evaluated on test data.
- train_loss
List of losses evaluated on train data.
Schedulers
- class LinearInverse(warmup=0, power=0.5)[source]
Bases:
ArgRepr
Scale up learning rate during warmup before decaying with inverse power.
Instances of this class are not learning-rate schedulers by themselves! They are intended to be passed as the lr_lambda argument to PyTorch’s LambdaLR learning-rate scheduler.
- Parameters:
warmup (int, optional) – Number of steps during which the learning rate will be linearly scaled up to the one specified in the optimizer. Defaults to 0, resulting in the learning rate already at its maximum value for the first step and only decaying thereafter.
power (float, optional) – After warmup steps, the learning rate is scaled down with the inverse step number taken to the power of this number. Values are clipped to lie in the interval [0.5, 1.0]. Defaults to 0.5, the slowest decay.
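A sketch of wiring this into the Trainer via PyTorch’s LambdaLR (assuming Curry completes the partial LambdaLR with the optimizer and forwards the keyword argument given here):

    from torch.optim.lr_scheduler import LambdaLR

    scheduler = Curry(LambdaLR, lr_lambda=LinearInverse(warmup=100, power=0.5))

    trainer = Trainer(
        loss,
        scheduler=scheduler,
        warmup=100,  # keep the Trainer's warmup in sync with the schedule
    )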
- class LinearExponential(warmup=0, gamma=0.95)[source]
Bases:
ArgRepr
Scale up learning rate during warmup before decaying exponentially.
Instances of this class are not learning-rate schedulers by themselves! They are intended to be passed as the lr_lambda argument to PyTorch’s LambdaLR learning-rate scheduler.
- Parameters:
warmup (int, optional) – Number of steps during which the learning rate will be linearly scaled up to the one specified in the optimizer. Defaults to 0, resulting in the learning rate already at its maximum value for the first step and only decaying thereafter.
gamma (float, optional) – After warmup steps, the learning rate is scaled down with this number to the power of the step number. Therefore, it must lie in the interval (0.0, 1.0). Defaults to 0.95.
- class LinearCosine(warmup=0, cooldown=100)[source]
Bases:
ArgRepr
Scale up learning rate during warmup before decaying with cosine.
Instances of this class are not learning-rate schedulers by themselves! They are intended to be passed as the lr_lambda argument to PyTorch’s LambdaLR learning-rate scheduler.
- Parameters:
warmup (int, optional) – Number of steps during which the learning rate will be linearly scaled up to the one specified in the optimizer. Defaults to 0, resulting in the learning rate already at its maximum value for the first step and only decaying thereafter.
cooldown (int, optional) – After warmup steps, the learning rate is scaled down with a cosine function for this many steps. Defaults to 100, but must be at least 1.
Notes
The learning rate will never actually reach 0. Rather, it stays at the last value right before reaching 0. How small this value is depends on the choice of cooldown. For cooldown = 1, it will stay at 1.0, for cooldown = 2, it will stay at 0.5, and so on.
Base classes
- class TestDataBase[source]
Bases:
ABC
- abstract property n
The total number of data points.
- abstractmethod sample(batch_size, max_n=None)[source]
Iterator over batches from a reproducible sample of your data.
Needed to consistently evaluate (and report) the train and/or test error after every epoch of training. Every time sample is called with the same arguments, the exact same output should be returned so that train and test errors can be used to track training convergence.
- Parameters:
batch_size (int) – Number of data points in (or batch_size of) the reproducible sample of your data. Every time sample is called with the same size, the exact same output should be returned so that train and test errors can be used to track training convergence.
max_n (int, optional) – Total number of data points to return mini-batches for. Defaults to None, resulting in all available data points being used. To save some time, however, you might not want to use all available training data points just for evaluating your model metrics.
- Returns:
tuple – A tuple of input tensors to call your model with.
Tensor – The matching target values. Must have the same dimensions and sizes as the output of your model.
Important
Even if your model only takes a single tensor as input, this method must return a tuple to cover the general case of your model taking more than one tensor as input. If this is not the case, simply return a 1-tuple of tensors.
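A minimal sketch of an implementation backed by two in-memory tensors (deterministic slicing keeps the sample reproducible across calls; adjust the import of TestDataBase to your installation):

    import torch
    from torch import Tensor

    class TensorTestData(TestDataBase):

        def __init__(self, x: Tensor, y: Tensor) -> None:
            self.x = x
            self.y = y

        @property
        def n(self) -> int:
            return self.x.shape[0]

        def sample(self, batch_size: int, max_n: int | None = None):
            n = self.n if max_n is None else min(max_n, self.n)
            for start in range(0, n, batch_size):
                stop = min(start + batch_size, n)
                # Features are wrapped in a 1-tuple, as required.
                yield (self.x[start:stop],), self.y[start:stop]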
- class TrainDataBase[source]
Bases:
TestDataBase
- abstractmethod __call__(batch_size, step_freq=1, epoch=0)[source]
Return an iterator over the mini-batches your model is trained on.
- Parameters:
batch_size (int) – The (maximum) number of data points in one batch.
step_freq (int, optional) – In case this number is > 1, the optimizer will accumulate gradients for that many batches before taking a step. All batches should be of the same size in this case, and there shouldn’t be any “left-over” batches at the end of each epoch. Defaults to 1.
epoch (int, optional) – In rare cases, it might be useful to know the current epoch when deciding which batches to return. Should start with 1 in the first epoch and, therefore, defaults to 0.
- Returns:
n_batches (int) – Total number of batches the returned iterator will provide. If unknown for some reason, None can also be returned.
batches (Iterator) – One element yielded by the iterator is a 2-tuple. Its first element is again a tuple, containing the input tensor(s) to call your model with. Its second element is a single tensor with the matching target values. It must have the same dimensions and sizes as the output of your model.
Important
Even if your model only takes a single tensor as input, the first element of one tuple yielded by the returned iterator must always be a tuple of tensors to cover the general case of your model taking more than one tensor as input. If this is not the case, simply make that a 1-tuple of tensors!
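Building on the TensorTestData sketch above, a minimal implementation might look like this (shuffling anew every epoch while keeping all batches the same size):

    import torch

    class TensorTrainData(TensorTestData, TrainDataBase):

        def __call__(self, batch_size: int, step_freq: int = 1, epoch: int = 0):
            # Drop data points and batches so that nothing is "left over".
            n = self.adjust_n_for(batch_size, step_freq)
            n_batches = self.adjust_batches_for(batch_size, step_freq)
            permutation = torch.randperm(self.n)[:n]

            def batches():
                for start in range(0, n, batch_size):
                    index = permutation[start:start + batch_size]
                    # Features again go into a 1-tuple of tensors.
                    yield (self.x[index],), self.y[index]

            return n_batches, batches()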
- adjust_batches_for(batch_size, step_freq=1, n=None)[source]
Number of batches reduced to be suitably integer-divisible.
This is a helper method for users to implement the __call__ method in the case of step_freq > 1. The returned number of batches is guaranteed to be integer-divisible by step_freq so that no batches are “left over” at the end of the epoch.
- Parameters:
batch_size (int) – The desired number of data points in one batch.
step_freq (int, optional) – In case this number is > 1, the optimizer will accumulate gradients for that many batches before taking a step. All batches should be of the same size in this case and there shouldn’t be any “left-over” batches at the end of each epoch. Defaults to 1.
n (int, optional) – In rare cases, it might be useful to pass in the number of data points to adjust rather than taking the number of data points returned by self.n, which is the default.
- Returns:
Reduced number of batches that is guaranteed to be integer divisible by step_freq.
- Return type:
int
- adjust_n_for(batch_size, step_freq=1, n=None)[source]
Number of data points reduced to be suitably integer-divisible.
This is a helper method for users to implement the __call__ method in the case of step_freq > 1. Taking only the returned number of data points guarantees that all batches have the same size and that there will be no “left-over” batches at the end of the epoch.
- Parameters:
batch_size (int) – The desired number of data points in one batch.
step_freq (int, optional) – In case this number is > 1, the optimizer will accumulate gradients for that many batches before taking a step. All batches should be of the same size in this case and there shouldn’t be any “left-over” batches at the end of each epoch. Defaults to 1.
n (int, optional) – In rare cases, it might be useful to pass in the number of data points to adjust rather than taking the number of data points returned by self.n, which is the default.
- Returns:
Reduced number of data points that is guaranteed to be integer divisible by the product of batch_size and step_freq.
- Return type:
int
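Assuming both helpers round their respective quantity down to the nearest suitable multiple, a worked example with 1,000 data points, a batch size of 64, and gradients accumulated over 4 batches might look like this (using the TensorTrainData sketch from above):

    data = TensorTrainData(torch.randn(1000, 8), torch.randn(1000, 1))

    n_points = data.adjust_n_for(64, step_freq=4)         # expected: 768 = 3 * 64 * 4
    n_batches = data.adjust_batches_for(64, step_freq=4)  # expected: 12, divisible by 4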
- abstract property n
The total number of data points.
- abstractmethod sample(batch_size, max_n=None)
Iterator over batches from a reproducible sample of your data.
Needed to consistently evaluate (and report) the train and/or test error after every epoch of training. Every time sample is called with the same arguments, the exact same output should be returned so that train and test errors can be used to track training convergence.
- Parameters:
batch_size (int) – Number of data points in (or batch_size of) the reproducible sample of your data. Every time sample is called with the same size, the exact same output should be returned so that train and test errors can be used to track training convergence.
max_n (int, optional) – Total number of data points to return mini-batches for. Defaults to None, resulting in all available data points being used. To save some time, however, you might not want to use all available training data points just for evaluating your model metrics.
- Returns:
tuple – A tuple of input tensors to call your model with.
Tensor – The matching target values. Must have the same dimensions and sizes as the output of your model.
Important
Even if your model only takes a single tensor as input, this method must return a tuple to cover the general case of your model taking more than one tensor as input. If this is not the case, simply return a 1-tuple of tensors.
- class StepCallback[source]
Bases:
ABC
Base class to inherit from when implementing custom step callbacks.
- abstractmethod __call__(train_loss, learning_rate, gradient_norm)[source]
Called after processing a batch to print, log, or save information.
- Parameters:
train_loss (float) – The training loss on the current batch.
learning_rate (float) – The learning rate used in the current optimization step.
gradient_norm (float) – Norm of the current gradients.
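A minimal sketch of a custom step callback that simply collects gradient norms (the class name is only an example):

    class GradientNormLogger(StepCallback):

        def __init__(self) -> None:
            self.norms: list[float] = []

        def __call__(
            self,
            train_loss: float,
            learning_rate: float,
            gradient_norm: float,
        ) -> None:
            self.norms.append(gradient_norm)

    trainer = Trainer(loss, step_cbs=(GradientNormLogger(),), cb_freq=10)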
- class EpochCallback[source]
Bases:
ABC
Base class to inherit from when implementing custom epoch callbacks.
- abstractmethod __call__(epoch, train_loss, test_loss, learning_rate, model, data)[source]
Called after each epoch to print, log, or otherwise analyse metrics.
- Parameters:
epoch (int) – The current epoch in the training loop.
train_loss (float) – The loss computed on a sample of the train data after the current epoch.
test_loss (float) – The loss computed on a sample of the test data after the current epoch. Always nan if no test data is used.
learning_rate (float) – Learning rate of the optimizer in the current epoch.
model (Module) – A reference to the model being trained.
data (Batches) – An iterator over 2-tuples of a feature-tensor tuple and a target tensor to call the model with. Sampled from the test data if present and from the train data otherwise.
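A minimal sketch of a custom epoch callback that records one loss per epoch (the model and data arguments could be used, e.g., to log example predictions after each epoch):

    import math

    class LossRecorder(EpochCallback):

        def __init__(self) -> None:
            self.losses: list[float] = []

        def __call__(self, epoch, train_loss, test_loss, learning_rate, model, data):
            # test_loss is nan when no test data is used; fall back to the train loss.
            self.losses.append(train_loss if math.isnan(test_loss) else test_loss)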
- class TrainCallback[source]
Bases:
ABC
Base class to inherit from when implementing custom train callbacks.
- abstractmethod __call__(epoch, best_epoch, best_loss, max_epochs_reached, history)[source]
Called after training has finished to print, log, or save a summary.
- Parameters:
epoch (int) – The last epoch in the training loop.
best_epoch (int) – The epoch with the lowest loss encountered.
best_loss (float) – The lowest loss encountered.
max_epochs_reached (bool) – Whether the maximum number of epochs was exhausted or not.
history (History) – Dictionary with lists of train losses, test losses, and learning rates.
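A minimal sketch of a custom train callback that persists the training summary as JSON once training has finished (the file name is only an example):

    import json

    class HistorySaver(TrainCallback):

        def __call__(self, epoch, best_epoch, best_loss, max_epochs_reached, history):
            summary = {
                'last_epoch': epoch,
                'best_epoch': best_epoch,
                'best_loss': best_loss,
                'max_epochs_reached': max_epochs_reached,
                'train_loss': history['train_loss'],
                'test_loss': history['test_loss'],
                'lr': history['lr'],
            }
            with open('training_history.json', 'w') as file:
                json.dump(summary, file, indent=2)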