Source code for swak.pt.misc.shape

"""Convenient classes (and functions) that do not fit any other category."""

from typing import Any, Self, overload
from collections.abc import Iterable
from functools import cached_property, singledispatchmethod
from itertools import chain
import torch as pt
from ...misc import ArgRepr
from ..types import Tensor, Tensors
from ..exceptions import ShapeError, DeviceError, DTypeError, ValidationErrors


class Stack(ArgRepr):
    """Simple partial of PyTorch's top-level `stack` function.

    Parameters
    ----------
    dim: int, optional
        The new dimension along which to stack the tensors. Defaults to 0.

    """

    def __init__(self, dim: int = 0) -> None:
        # The argument is also handed to the parent class initializer.
        super().__init__(dim)
        self.dim = dim

    def __call__(self, tensors: Tensors | list[Tensor]) -> Tensor:
        """Stack the given tensors along a new dimension.

        Parameters
        ----------
        tensors: tuple or list of tensors
            Tensors to stack along a new dimension.

        Returns
        -------
        Tensor
            The stacked tensors.

        """
        return pt.stack(tensors, dim=self.dim)
class Cat(ArgRepr):
    """Callable that applies PyTorch's top-level `cat` with a fixed `dim`.

    Parameters
    ----------
    dim: int, optional
        The existing dimension along which tensors are concatenated.
        Defaults to 0.

    """

    def __init__(self, dim: int = 0) -> None:
        # The argument is also handed to the parent class initializer.
        super().__init__(dim)
        self.dim = dim

    def __call__(self, tensors: Tensors | list[Tensor]) -> Tensor:
        """Join the given tensors along the configured dimension.

        Parameters
        ----------
        tensors: tuple or list of tensors
            Tensors to concatenate along an existing dimension.

        Returns
        -------
        Tensor
            The concatenated tensors.

        """
        axis = self.dim
        return pt.cat(tensors, dim=axis)
[docs] class LazyCatDim0: """Lazily concatenate a sequence of tensors along their first dimension. Concatenating a large number of even small tensors (or a small number of large tensors) causes a memory spike because, temporarily, two copies of all tensors are needed. Sometimes, this simply cannot be avoided. However, when only a small part of the full concatenation of all tensors is needed at any given time, e.g., when chopping off micro-batches of training data to feed to a model, the present class provides an alternative: Constituent tensors are kept as is and concatenation is only performed when slices or element(s) along the first dimension are requested. These are selected from the constituents first and only then concatenated (and, thus, copied). Slicing and element selection of further dimensions is delayed until selection and concatenation along the first dimension is completed. Parameters ---------- tensors: iterable The iterable of tensors to cache, all of which must have the same number of dimensions and the exact same sizes in all dimensions but the first. Raises ------ ShapeError If any tensor is a scalar, that is, has zero dimensions, or if the shape after the first dimension is not the same across all tensors. Also, if there are no tensors to wrap. DeviceError If tensors are spread over multiple devices. DTypeError If tensors have multiple dtypes. """ def __init__(self, tensors: Iterable[Tensor]) -> None: self.__tensors = self._valid(tuple(tensors)) @staticmethod def _valid(tensors: Tensors) -> Tensors: """Run a few validations on the homogeneity of the wrapped tensors.""" if not tensors: msg = 'Expected a non-empty iterable of tensors!' raise ShapeError(msg) errors = [] if any(tensor.dim() == 0 for tensor in tensors): msg = 'Scalar tensors can not be concatenated!' 
errors.append(ShapeError(msg)) _, *shape = tensors[0].shape if any(list(tensor.shape[1:]) != shape for tensor in tensors[1:]): msg = 'All tensors must be of shape {} after the first dimension!' errors.append(ShapeError(msg.format(shape))) if any(tensor.device != tensors[0].device for tensor in tensors[1:]): msg = 'All tensors must be on the same device!' errors.append(DeviceError(msg)) if any(tensor.dtype is not tensors[0].dtype for tensor in tensors[1:]): msg = 'All tensors must have the same dtype!' errors.append(DTypeError(msg)) if errors: raise ValidationErrors('Validation failed', errors) return tensors @cached_property def lookup(self) -> tuple[tuple[int, int], ...]: """A lazily computed and cached lookup table for indices.""" return tuple( (i, j) for i, tensor in enumerate(self.__tensors) for j in range(tensor.size(0)) ) def __repr__(self) -> str: cls = self.__class__.__name__ return f'{cls}(n={len(self.lookup)})' def __len__(self) -> int: return len(self.lookup) def __iter__(self) -> chain[Tensor]: return chain.from_iterable(self.__tensors) def __contains__(self, elem: Any) -> bool: return any(elem in tensor for tensor in self.__tensors) @singledispatchmethod def __getitem__(self, index) -> Tensor: idx, elem = self.lookup[index] return self.__tensors[idx][elem] @__getitem__.register def _(self, index: slice) -> Tensor: return pt.cat([ self.__tensors[idx][elem:elem + 1] for idx, elem in self.lookup[index.start:index.stop] ])[::index.step] @__getitem__.register def _(self, index: list) -> Tensor: return pt.cat([ self.__tensors[idx][elem:elem + 1] for idx, elem in (self.lookup[idx] for idx in index) ]) @__getitem__.register def _(self, index: tuple) -> Tensor: if not index: return self[:] idx, *indices = index match idx: case slice(): return self[idx][:, *indices] case int(): return self[idx][*indices] @property def dtype(self) -> pt.dtype: """The common dtype of all cached tensors.""" return self.__tensors[0].dtype @property def device(self) -> pt.device: 
"""The common device of all cached tensors.""" return self.__tensors[0].device @cached_property def shape(self) -> pt.Size: """The shape of the full concatenation of the wrapped tensors.""" _, *shape = self.__tensors[0].shape return pt.Size([len(self), *shape]) @overload def size(self, dim: int) -> int: ... @overload def size(self, dim: None = None) -> pt.Size: ...
[docs] def size(self, dim = None): """The size of the full concatenation of the wrapped tensors. Parameters ---------- dim: int, optional The dimension for which to return the size. Defaults to ``None``. If not given the `shape` of the full concatenation of all cached tensors is returned. Returns ------- int If the size along a certain dimension was requested. Size A PyTorch ``Size`` object specifying the `shape` of the full concatenation of all cached tensors. """ return self.shape[dim] if dim is not None else self.shape
[docs] def to(self, *args: Any) -> Self: """Returns a new instance of self, wrapping transformed tensors. See the PyTorch `documentation <https://pytorch.org/docs/stable/ generated/torch.Tensor.to.html#torch-tensor-to>`_ for possible call signatures, but keep in mind that all listed operations will create a copy of the wrapped tensors after all. """ return self.__class__(tensor.to(*args) for tensor in self.__tensors)