# Source code for swak.pt.train.callbacks

from typing import TypedDict, Any
from collections.abc import Iterator
from abc import ABC, abstractmethod
from collections.abc import Callable
from ...misc import ArgRepr
from ..types import Module, Batches

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    'StepCallback',
    'StepPrinter',
    'EpochCallback',
    'EpochPrinter',
    'History',
    'TrainCallback',
    'TrainPrinter'
]


[docs] class History(TypedDict): """Summary of metrics passed to the callback when training is finished.""" train_loss: list[float] #: List of losses evaluated on train data. test_loss: list[float | None] #: List of losses evaluated on test data. lr: list[float] #: List of learning rates.
class StepCallback(ABC):
    """Base class to inherit from when implementing custom step callbacks."""

    @abstractmethod
    def __call__(
            self,
            train_loss: float,
            learning_rate: float,
            gradient_norm: float
    ) -> None:
        """Invoked after every processed batch to print, log, or save info.

        Parameters
        ----------
        train_loss: float
            Training loss evaluated on the current batch.
        learning_rate: float
            Learning rate applied in the current optimization step.
        gradient_norm: float
            Norm of the current gradients.

        """
        ...

    @abstractmethod
    def close(self) -> None:
        """Invoked once training ends, for compatibility with TensorBoard."""
        ...
class EpochCallback(ABC):
    """Base class to inherit from when implementing custom epoch callbacks."""

    @abstractmethod
    def __call__(
            self,
            epoch: int,
            train_loss: float,  # was annotated ``int``; docstring and callers use float
            test_loss: float,   # was annotated ``int``; may be ``nan`` (a float)
            learning_rate: float,
            model: Module,
            data: Batches,
    ) -> None:
        """Called after each epoch to print, log, or otherwise analyse metrics.

        Parameters
        ----------
        epoch: int
            The current epoch in the training loop.
        train_loss: float
            The loss computed on a sample of the train data after the
            current `epoch`.
        test_loss: float
            The loss computed on a sample of the test data after the current
            `epoch`. Always ``nan`` if no test data is used.
        learning_rate: float
            Learning rate of the optimizer in the current `epoch`.
        model: Module
            A reference to the model being trained.
        data: Batches
            An iterator over 2-tuples of a feature-tensor tuple and a target
            tensor to call the `model` with. Sampled from the test data if
            present and from the train data otherwise.

        """
        ...

    @abstractmethod
    def close(self) -> None:
        """Called at the end of training to be compatible with TensorBoard."""
        ...
class TrainCallback(ABC):
    """Base class to inherit from when implementing custom train callbacks."""

    # NOTE(review): unlike Step/EpochCallback, no abstract close() is declared
    # here — confirm that train callbacks are never closed by the trainer.
    @abstractmethod
    def __call__(
            self,
            epoch: int,
            best_epoch: int,
            best_loss: float,
            max_epochs_reached: bool,
            history: History
    ) -> None:
        """Invoked once training has finished to print, log, or save a summary.

        Parameters
        ----------
        epoch: int
            The last epoch in the training loop.
        best_epoch: int
            The epoch with the lowest loss encountered.
        best_loss: float
            The lowest loss encountered.
        max_epochs_reached: bool
            Whether the maximum number of epochs was exhausted or not.
        history: History
            Dictionary with lists of train losses, test losses, and
            learning rates.

        """
        ...
class StepPrinter(ArgRepr, StepCallback):
    """Step callback assembling a one-liner on per-batch training progress.

    Parameters
    ----------
    printer: callable, optional
        Will be called with the assembled message. Defaults to the python
        builtin ``print`` function, but could also be a logging command.
    sep: str, optional
        The items concatenated into a one-line message will be separated by
        this string. Defaults to ", ".

    """

    def __init__(
            self,
            printer: Callable[[str], Any] = print,
            sep: str = ', '
    ) -> None:
        super().__init__(printer, sep)
        self.printer = printer
        self.sep = sep

    def __iter__(self) -> Iterator['StepPrinter']:
        # Lets a single callback be used where an iterable of them is expected.
        return iter([self])

    def __call__(
            self,
            train_loss: float,
            learning_rate: float,
            gradient_norm: float
    ) -> None:
        """Assemble one-liner with loss, lr, and norm, and call the printer."""
        msg = self.sep.join([
            f'{train_loss:7.5f}',
            f'{learning_rate:7.5f}',
            f'{gradient_norm:7.5f}'
        ])
        self.printer(msg)

    def close(self) -> None:
        """Does nothing because there is nothing to close."""
class EpochPrinter(ArgRepr, EpochCallback):
    """Epoch callback assembling an informative message on training progress.

    Parameters
    ----------
    printer: callable, optional
        Will be called with the assembled message. Defaults to the python
        builtin ``print`` function, but could also be a logging command.

    """

    def __init__(self, printer: Callable[[str], Any] = print) -> None:
        super().__init__(printer)
        self.printer = printer

    def __iter__(self) -> Iterator['EpochPrinter']:
        # Lets a single callback be used where an iterable of them is expected.
        return iter([self])

    def __call__(
            self,
            epoch: int,
            train_loss: float,
            test_loss: float,
            learning_rate: float,
            model: Module,
            data: Batches
    ) -> None:
        """Assemble training-progress message and call the printer with it."""
        # Joined with ' | ', these fields reproduce the exact original line:
        # "Epoch: ... | learning rate: ... | train loss: ... | test loss: ..."
        fields = (
            f'Epoch: {epoch:>3}',
            f'learning rate: {learning_rate:7.5f}',
            f'train loss: {train_loss:7.5f}',
            f'test loss: {test_loss:7.5f}',
        )
        self.printer(' | '.join(fields))

    def close(self) -> None:
        """Does nothing because there is nothing to close."""
class TrainPrinter(ArgRepr, TrainCallback):
    """Train callback assembling an informative message when training ends.

    Parameters
    ----------
    printer: callable, optional
        Will be called with the assembled message. Defaults to the python
        builtin ``print`` function, but could also be a logging command.

    """

    def __init__(self, printer: Callable[[str], Any] = print) -> None:
        super().__init__(printer)
        self.printer = printer

    def __iter__(self) -> Iterator['TrainPrinter']:
        # Lets a single callback be used where an iterable of them is expected.
        return iter([self])

    def __call__(
            self,
            epoch: int,
            best_epoch: int,
            best_loss: float,
            max_epochs_reached: bool,
            history: History
    ) -> None:
        """Assemble a summary of model training and call the printer with it.

        Parameters
        ----------
        epoch: int
            The last epoch in the training loop.
        best_epoch: int
            The epoch with the lowest loss encountered.
        best_loss: float
            The lowest loss encountered.
        max_epochs_reached: bool
            Whether the maximum number of epochs was exhausted or not.
        history: History
            Dictionary with lists of train losses, test losses, and
            learning rates.

        """
        # Guard clause for the simple case: training simply ran out of epochs.
        if max_epochs_reached:
            self.printer(f'Maximum number of {epoch} epochs exhausted!')
            return
        waited = epoch - best_epoch
        msg = (f'Stopping after {epoch} epochs because, even after waiting'
               f' for {waited} more epochs, the loss did not '
               f'drop below the lowest value of {best_loss:7.5f} seen in '
               f'epoch {best_epoch}. Recovering that checkpoint.')
        self.printer(msg)