io
Save and load (the state of) PyTorch models to and from any filesystem.
- class ModelLoader(path='', storage=Storage.FILE, chunk_size=32, storage_kws=None, map_location=None)
Bases: Reader
Load a previously saved model from any supported file system.
- Parameters:
path (str, optional) – Full path to the model to load. Since it (or part of it) can also be provided later, when the callable instance is called, it is optional here. Defaults to an empty string.
storage (str, optional) – The type of file system to read from (“file”, “s3”, etc.). Defaults to “file”. Use the Storage enum to avoid typos.
chunk_size (float, optional) – Chunk size to use when reading from the selected file system in MiB. Defaults to 32 (MiB).
storage_kws (dict, optional) – Passed on as keyword arguments to the constructor of the file system.
map_location (str or Device, optional) – The device to load the model onto. Defaults to None, which loads to the PyTorch default device.
- Raises:
TypeError – If path is not a string, chunk_size is not a float, or if storage_kws is not a dictionary.
ValueError – If storage is not among the currently supported file-system schemes, if mode is not among the supported file-mode options, if chunk_size is smaller than 1 (MiB), or if storage_kws is not a dictionary.
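The Raises list above implies that the constructor arguments are checked up front. The following is a minimal sketch of such checks, not this module's code: the `Storage` class below is a hypothetical stand-in for the real enum (only the “file” and “s3” schemes named above are included), and `validate` is a hypothetical helper. It also shows why a string-valued enum helps avoid typos: unknown scheme names fail immediately with a ValueError.

```python
from enum import Enum


class Storage(str, Enum):
    """Hypothetical stand-in for the Storage enum; the real one may offer more schemes."""
    FILE = "file"
    S3 = "s3"


def validate(path="", storage="file", chunk_size=32, storage_kws=None):
    """Hypothetical sketch of the argument checks listed under Raises."""
    if not isinstance(path, str):
        raise TypeError("path must be a string!")
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(chunk_size, bool) or not isinstance(chunk_size, (int, float)):
        raise TypeError("chunk_size must be a number (in MiB)!")
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1 (MiB)!")
    storage_kws = {} if storage_kws is None else storage_kws
    if not isinstance(storage_kws, dict):
        raise TypeError("storage_kws must be a dictionary!")
    # Storage(...) accepts enum members and matching plain strings alike;
    # an unknown scheme such as "ftp" raises a ValueError.
    return path, Storage(storage), float(chunk_size), storage_kws
```

Validating eagerly like this means a misspelled scheme surfaces when the loader is created rather than when a file is first read.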
- __call__(path='')
Load a previously saved model from any supported filesystem.
- Parameters:
path (str, optional) – Path (including file name) to the model file to load. If it starts with a forward slash, it is interpreted as absolute; otherwise, it is interpreted as relative to the path specified at instantiation. Defaults to an empty string, which leaves the path given at instantiation unchanged.
- Returns:
The loaded model.
- Return type:
Module
- Raises:
ValueError – If the final path is directly under root (e.g., “/file.pt”) because, on a local file system, this is not where model files should be kept and, on object storage, the first directory refers to the name of an (existing!) bucket.
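How a path passed at call time could combine with the one given at instantiation is sketched below. This is a hypothetical helper (`resolve_path` is not part of the module), and it assumes POSIX-style paths in which a leading forward slash marks an absolute path and a relative path is simply joined onto the instantiation path.

```python
import posixpath


def resolve_path(base: str, path: str = "") -> str:
    """Hypothetical sketch: combine a call-time path with the instantiation path."""
    if path.startswith("/"):
        final = path  # absolute: replaces the path given at instantiation
    elif path:
        final = posixpath.join(base, path)  # relative: joined onto the base
    else:
        final = base  # empty: the path given at instantiation is used unchanged
    # Normalize to exactly one leading slash and no trailing slash.
    final = "/" + final.strip("/")
    # Refuse files sitting directly under root, such as "/file.pt".
    if posixpath.dirname(final) == "/":
        raise ValueError(f"Path {final} is directly under root!")
    return final
```

For example, with a base of “/models”, calling with “net.pt” would target “/models/net.pt”, while calling with “/archive/net.pt” would ignore the base entirely.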
- class ModelSaver(path, storage=Storage.FILE, overwrite=False, skip=False, chunk_size=32, storage_kws=None)
Bases: Writer
Save an entire model to any supported file system.
- Parameters:
path (str) – The absolute path to the file to save the model to. May include two or more forward slashes (subdirectories will be created) and string placeholders (i.e., pairs of curly brackets) that will be interpolated when instances are called.
storage (str, optional) – The type of file system to write to (“file”, “s3”, etc.). Defaults to “file”. Use the Storage enum to avoid typos.
overwrite (bool, optional) – Whether to silently overwrite the destination file. Defaults to False, which will raise an exception if it already exists.
skip (bool, optional) – Whether to silently do nothing if the target file already exists. Defaults to False.
chunk_size (float, optional) – Chunk size to use when writing to the selected file system in MiB. Defaults to 32 (MiB).
storage_kws (dict, optional) – Passed on as keyword arguments to the constructor of the file system.
- Raises:
TypeError – If path is not a string, chunk_size is not a float, or if storage_kws is not a dictionary.
ValueError – If storage is not among the currently supported file-system schemes, if mode is not among the supported file-mode options, if chunk_size is smaller than 1 (MiB), or if storage_kws is not a dictionary.
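The interplay of overwrite and skip can be illustrated for a local file with the hypothetical helper `save_bytes` below, which is a sketch and not the module's implementation. It assumes that skip takes precedence when both flags are True (the documentation above does not say), creates missing parent directories as described for path, and returns an empty tuple like the instances do.

```python
import os


def save_bytes(data: bytes, path: str, overwrite: bool = False, skip: bool = False) -> tuple:
    """Hypothetical sketch of the overwrite/skip logic for a local file."""
    if os.path.exists(path):
        if skip:
            return ()  # assumption: skip wins over overwrite; silently do nothing
        if not overwrite:
            raise FileExistsError(f"File {path} already exists!")
    # Create any missing parent directories before writing.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as file:
        file.write(data)
    return ()  # mirrors the documented empty-tuple return value
```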
- __call__(model, *parts)
Save a model to file on any supported file system.
- Parameters:
model (Module) – The model to save.
*parts (str, optional) – Fragments that will be interpolated into the path string given at instantiation. There must be at least as many fragments as there are placeholders in the path.
- Returns:
An empty tuple.
- Return type:
tuple
- Raises:
IndexError – If the path given at instantiation has more string placeholders than there are parts.
FileExistsError – If the destination file already exists, skip is False, and overwrite is also False.
ValueError – If the final path is directly under root (e.g., “/file.pt”) because, on a local file system, this is not where you want to save to and, on object storage, the first directory refers to the name of an (existing!) bucket.
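Filling the curly-bracket placeholders with *parts behaves much like Python’s `str.format`, which is assumed in the hypothetical sketch below (the module may interpolate differently). It reproduces the documented IndexError when there are fewer parts than placeholders, while surplus parts are silently ignored.

```python
def interpolate(path: str, *parts: str) -> str:
    """Hypothetical sketch: fill curly-bracket placeholders with call-time parts."""
    # str.format raises an IndexError when there are fewer parts than
    # placeholders, and silently ignores any surplus parts.
    return path.format(*parts)


# One saver instance can thus target many destination files.
template = "/models/{}/epoch-{}.pt"
print(interpolate(template, "resnet", "07"))  # /models/resnet/epoch-07.pt
```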
- class StateLoader(path, storage=Storage.FILE, chunk_size=32, storage_kws=None, map_location=None, merge=True, not_found=NotFound.RAISE)
Bases: Reader
Load the state of a model from any supported file system.
- Parameters:
path (str) – Full path to the file that holds the model’s state_dict(). May include two or more forward slashes (subdirectories will be created) and string placeholders (i.e., pairs of curly brackets) that will be interpolated when instances are called.
storage (str, optional) – The type of file system to read from (“file”, “s3”, etc.). Defaults to “file”. Use the Storage enum to avoid typos.
chunk_size (float, optional) – Chunk size to use when reading from the selected file system in MiB. Defaults to 32 (MiB).
storage_kws (dict, optional) – Passed on as keyword arguments to the constructor of the file system.
map_location (str or Device, optional) – The device to load the state onto. Defaults to None, which loads to the PyTorch device(s) that were saved with the model.
merge (bool, optional) – Whether the loaded state should be merged into the state of the model (True) or replace its state (False). Merging allows loading only a partial state. Defaults to True.
not_found (str, optional) – What to do if the specified file is not found. One of “ignore”, “warn”, or “raise” (use the NotFound enum to avoid typos). Defaults to “raise”. If set to “ignore” or “warn” and the specified file is not found, merge is overridden to True, thus returning the unaltered model.
- Raises:
TypeError – If path is not a string, chunk_size is not a float, or if storage_kws is not a dictionary.
ValueError – If storage is not among the currently supported file-system schemes, if mode is not among the supported file-mode options, if not_found is not a permitted string, if chunk_size is smaller than 1 (MiB), or if storage_kws is not a dictionary.
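The three not_found options can be pictured with the hypothetical sketch below. The `NotFound` class is a stand-in for the real enum, `handle_missing` is an illustrative helper, and raising a FileNotFoundError for “raise” is an assumption, since the exception type is not documented above. In the “ignore” and “warn” cases, the caller would then return the model unaltered.

```python
import warnings
from enum import Enum


class NotFound(str, Enum):
    """Hypothetical stand-in for the NotFound enum."""
    IGNORE = "ignore"
    WARN = "warn"
    RAISE = "raise"


def handle_missing(path: str, not_found: str = NotFound.RAISE) -> None:
    """Hypothetical sketch: react to a missing state file according to not_found."""
    # NotFound(...) accepts enum members and plain strings; anything else
    # raises a ValueError, mirroring the documented validation.
    action = NotFound(not_found)
    if action is NotFound.RAISE:
        raise FileNotFoundError(f"File {path} not found!")  # assumed exception type
    if action is NotFound.WARN:
        warnings.warn(f"File {path} not found!")
    # IGNORE: do nothing; the caller keeps the model's current state.
```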
- __call__(model, *parts)
Load the state of a model from file on any supported filesystem.
- Parameters:
model (Module) – Model to load the state of.
*parts (str, optional) – Fragments that will be interpolated into the path string given at instantiation. There must be at least as many fragments as there are placeholders in the path.
- Returns:
The model with its state restored.
- Return type:
Module
- Raises:
ValueError – If the final path is directly under root (e.g., “/file.pt”) because, on a local file system, this is not where model files should be kept and, on object storage, the first directory refers to the name of an (existing!) bucket.
RuntimeError – If merge is set to False and the loaded state has fewer keys than the state of the model, or if the loaded state has more keys than the model.
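The merge semantics and the RuntimeError conditions above can be mimicked with plain dictionaries standing in for state dicts. The sketch below uses a hypothetical helper, `combine_states`, and compares key sets rather than key counts, which is an assumption. For real modules, PyTorch’s own `Module.load_state_dict(state_dict, strict=True)` raises a RuntimeError on missing or unexpected keys in a similar spirit, but that is not a claim about how this module is implemented.

```python
def combine_states(model_state: dict, loaded: dict, merge: bool = True) -> dict:
    """Hypothetical sketch of merging into, or replacing, a model's state."""
    # Keys the model does not know about are an error either way.
    unexpected = loaded.keys() - model_state.keys()
    if unexpected:
        raise RuntimeError(f"Unexpected keys: {sorted(unexpected)}")
    if merge:
        # Partial state: loaded entries overwrite, all others are kept.
        return {**model_state, **loaded}
    # Full replacement: every key of the model must be present.
    missing = model_state.keys() - loaded.keys()
    if missing:
        raise RuntimeError(f"Missing keys: {sorted(missing)}")
    return dict(loaded)
```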
- class StateSaver(path, storage=Storage.FILE, overwrite=False, skip=False, chunk_size=32, storage_kws=None)
Bases: Writer
Save the state of a model to any supported file system.
- Parameters:
path (str) – The absolute path to the file to save a model’s state_dict() to. May include two or more forward slashes (subdirectories will be created) and string placeholders (i.e., pairs of curly brackets) that will be interpolated when instances are called.
storage (str, optional) – The type of file system to write to (“file”, “s3”, etc.). Defaults to “file”. Use the Storage enum to avoid typos.
overwrite (bool, optional) – Whether to silently overwrite the destination file. Defaults to False, which will raise an exception if it already exists.
skip (bool, optional) – Whether to silently do nothing if the target file already exists. Defaults to False.
chunk_size (float, optional) – Chunk size to use when writing to the selected file system in MiB. Defaults to 32 (MiB).
storage_kws (dict, optional) – Passed on as keyword arguments to the constructor of the file system.
- Raises:
TypeError – If path is not a string, chunk_size is not a float, or if storage_kws is not a dictionary.
ValueError – If storage is not among the currently supported file-system schemes, if mode is not among the supported file-mode options, if chunk_size is smaller than 1 (MiB), or if storage_kws is not a dictionary.
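What writing “in chunks of chunk_size MiB” amounts to can be sketched for a local file as follows. The helper `write_in_chunks` is hypothetical; it assumes the state has already been serialized into an in-memory buffer (with PyTorch, `torch.save(model.state_dict(), buffer)` accepts a file-like object such as `io.BytesIO`) and copies that buffer to disk piece by piece.

```python
import io

MIB = 1024 * 1024  # bytes per MiB


def write_in_chunks(buffer: io.BytesIO, path: str, chunk_size: float = 32) -> int:
    """Hypothetical sketch: copy a serialized buffer to a local file chunk by chunk."""
    # chunk_size is given in MiB; guard against a zero-byte read size.
    n_bytes = max(1, int(chunk_size * MIB))
    buffer.seek(0)  # start from the beginning of the serialized data
    written = 0
    with open(path, "wb") as file:
        # iter() with a b"" sentinel stops once the buffer is exhausted.
        for chunk in iter(lambda: buffer.read(n_bytes), b""):
            written += file.write(chunk)
    return written  # total number of bytes written
```

Larger chunks mean fewer write calls at the cost of more memory per call; on object storage, the chunk size would typically map onto the size of the individual upload parts.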
- __call__(model, *parts)
Save the state of a model, optimizer, or scheduler to a file.
- Parameters:
model (Module) – Model to save the state of.
*parts (str, optional) – Fragments that will be interpolated into the path string given at instantiation. There must be at least as many fragments as there are placeholders in the path.
- Returns:
An empty tuple.
- Return type:
tuple
- Raises:
TypeError – If the model does not have a callable state_dict method.
IndexError – If the path given at instantiation has more string placeholders than there are parts.
FileExistsError – If the destination file already exists, skip is False, and overwrite is also False.
ValueError – If the final path is directly under root (e.g., “/file.pt”) because, on a local file system, this is not where you want to save to and, on object storage, the first directory refers to the name of an (existing!) bucket.
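Because models, optimizers, and learning-rate schedulers in PyTorch all expose a `state_dict()` method, a saver can rely on duck typing rather than on a specific class. The hypothetical helper `get_state` below sketches the documented TypeError check; it is an illustration, not the module’s code.

```python
def get_state(obj) -> dict:
    """Hypothetical sketch: fetch the state of anything exposing state_dict()."""
    state_dict = getattr(obj, "state_dict", None)
    # Reject objects without the method or with a non-callable attribute.
    if not callable(state_dict):
        raise TypeError(f"{type(obj).__name__} has no callable state_dict method!")
    return state_dict()
```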