transformer

Custom components for modern transformer architectures.

These include, but are not limited to, sinusoidal and rotary position encodings (RoPE), flexible grouped-query self-attention, and a full-blown encoder-only transformer built from these components.

class EncoderLayer(attention, feed_forward, pos_enc=None, dropout=0.1, norm_first=True, norm_cls=<class 'torch.nn.modules.normalization.LayerNorm'>, *args, device='cpu', dtype=torch.float32, **kwargs)

Bases: Trafo

Encoder layer (i.e., self-attention only) to use in a transformer.

Parameters:
  • attention (Attn) – A suitably parameterized instance of a self-attention block, typically MultiheadedSelfAttention or GroupedQuerySelfAttention.

  • feed_forward (Block) –

    PyTorch Module that

    • has a reset_parameters() method,

    • has a new() method to make fresh copies of itself,

    • processes tensors with dimensions (…, S, D),

    where S is the sequence length and D is the model dimension specified in the attention.

  • pos_enc (PosEnc, optional) –

    PyTorch Module that

    • has a reset_parameters() method,

    • has a new() method to make fresh copies of itself,

    • has a context attribute specifying the maximum sequence length,

    • processes tensors with dimensions (…, S, D),

    where S is the sequence length and D is the model dimension specified in the attention. If given, it will be called on the input tensor first thing. Typically, this would be an instance of Sinusoidal or Learnable positional encodings. Defaults to an instance of IdentityBlock, which does nothing.

  • dropout (float, optional) – Dropout probability applied to the outputs of self-attention and feed-forward. Defaults to 0.1.

  • norm_first (bool, optional) – If True, normalize the inputs to attention and feed-forward (pre-norm). If False, normalize the sums of their respective inputs and outputs (post-norm). Defaults to True.

  • norm_cls (type, optional) – Which type of norm to use between (sub-)layers. Must be one of torch.nn.LayerNorm (the default) or torch.nn.RMSNorm.

  • *args – Arguments used to initialize an instance of norm_cls.

  • device (str or torch.device, optional) – Torch device to first create the encoder layer on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype to first create the layer in. Defaults to torch.float.

  • **kwargs – Keyword arguments used to initialize an instance of norm_cls.
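The feed_forward contract above is duck-typed, so any module with those three properties should work. Below is a minimal sketch of a compatible block, assuming a plain two-layer MLP with GELU activation is wanted. The class FeedForward and its hidden_dim argument are hypothetical and not part of this module.

```python
import torch
from torch import nn


class FeedForward(nn.Module):
    """Hypothetical position-wise MLP meeting the feed_forward contract."""

    def __init__(self, mod_dim: int, hidden_dim: int, bias: bool = True,
                 device="cpu", dtype=torch.float32):
        super().__init__()
        # Keep the hyper-parameters so that new() can rebuild the module.
        self.mod_dim = mod_dim
        self.hidden_dim = hidden_dim
        self.bias = bias
        self.up = nn.Linear(mod_dim, hidden_dim, bias=bias, device=device, dtype=dtype)
        self.down = nn.Linear(hidden_dim, mod_dim, bias=bias, device=device, dtype=dtype)
        self.act = nn.GELU()

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # Acts on the last dimension only, so leading (..., S) dims pass through.
        return self.down(self.act(self.up(src)))

    def reset_parameters(self) -> None:
        # Re-draw the weights of both projections with PyTorch's default init.
        self.up.reset_parameters()
        self.down.reset_parameters()

    def new(self) -> "FeedForward":
        # Same hyper-parameters, device, and dtype, but freshly initialized weights.
        return self.__class__(self.mod_dim, self.hidden_dim, self.bias,
                              device=self.up.weight.device, dtype=self.up.weight.dtype)


ff = FeedForward(mod_dim=64, hidden_dim=256)
x = torch.randn(2, 10, 64)  # (batch, S, D)
y = ff(x)
twin = ff.new()
```

Because new() rebuilds the module from its stored hyper-parameters, the copy gets its own, independently initialized weights instead of sharing them.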

property context

Maximum context length given by the positional encodings.

property device

The device all weights, biases, activations, etc. reside on.

property dtype

The dtype of all weights, biases, activations, and parameters.

forward(src, mask=None, is_causal=True)

Forward pass of one encoder layer (i.e., with self-attention only).

Parameters:
  • src (Tensor) – Input sequence(s) of dimensions (…, S, D), with sequence length S and model dimension D.

  • mask (Tensor, optional) – Attention mask with a shape broadcastable to the shape of the attention weights (…, S, S). Two types of masks are supported: a boolean mask, where True indicates that the element should take part in attention, or a float mask of the same dtype as src, which is added to the product of queries and keys before the softmax. In the latter case, a value of 0.0 (leaving the attention weights unchanged) indicates that an element should take part in attention and a value of “-inf” (resulting in a zero attention weight) that it should not. Defaults to None.

  • is_causal (bool, optional) – If set to True, inputs are masked with an S x S lower triangular matrix and mask is ignored. Defaults to True.

Returns:

The output has the same shape as the input.

Return type:

Tensor

Important

Following the convention of scaled_dot_product_attention, the meaning of True and False (attend to and do not attend to, respectively) in boolean attention masks is exactly the opposite of their meaning in torch.nn.Transformer. Therefore, to stay compatible, use float masks!
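A boolean mask written for torch.nn.Transformer (True means "do not attend") has to be inverted before it matches the convention used here (True means "attend"). The sketch below converts a boolean mask into the recommended float form and checks, with torch.nn.functional.scaled_dot_product_attention alone, that both forms and is_causal=True give the same result. The helper bool_to_float_mask is hypothetical.

```python
import torch
import torch.nn.functional as F


def bool_to_float_mask(attend: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Hypothetical helper: True (attend) -> 0.0, False (ignore) -> -inf."""
    mask = torch.zeros(attend.shape, dtype=dtype, device=attend.device)
    return mask.masked_fill(~attend, float("-inf"))


S, D = 4, 8
q = k = v = torch.randn(1, S, D)

# scaled_dot_product_attention convention: True marks keys that MAY be attended to.
attend = torch.ones(S, S, dtype=torch.bool).tril(diagonal=0)
# torch.nn.Transformer convention: True marks keys that must NOT be attended to.
blocked = ~attend

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=attend)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_to_float_mask(attend))
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Passing blocked directly would silently do the opposite of what was intended, which is why float masks are the safer choice.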

property has_pos_enc

Whether positional encodings are applied.

property mod_dim

The model dimension.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

EncoderLayer

reset_parameters()

Reset all internal parameters of the layer.

class Encoder(layer, n_layers=1, pos_enc=None, dropout=0.0, device='cpu', dtype=torch.float32)

Bases: Trafo

Flexible transformer encoder.

Parameters:
  • layer (EncoderLayer) – A suitably parameterized instance of EncoderLayer.

  • n_layers (int, optional) – How many times the layer is repeated in the transformer stack. Must be at least 1 (the default).

  • pos_enc (PosEnc, optional) –

    PyTorch Module that

    • has a reset_parameters() method,

    • has a context attribute specifying the maximum sequence length,

    • processes tensors with dimensions (…, S, D),

    where S is the sequence length and D is the model dimension specified in the layer. If given, it will be called on the input tensor first thing. Typically, this would be an instance of Sinusoidal or Learnable positional encodings. Defaults to an instance of IdentityBlock, which does nothing.

  • dropout (float, optional) – Apply dropout to the sum of token embeddings and positional encodings with this probability during training. Defaults to 0.

  • device (str or torch.device, optional) – Torch device to first create the transformer on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype to first create the transformer encoder stack in. Defaults to torch.float.

Raises:

ValueError – If n_layers is less than 1.

Note

If the layer sets norm_first to True, no norm is applied to the final output of the Encoder. If a trailing norm is desired, it should be applied externally, after this module.
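Schematically, norm_first chooses between the two residual arrangements sketched below. This is a generic illustration of pre-norm versus post-norm sub-layers, not this module's code. A plain nn.Linear stands in for attention or feed-forward, and the function sublayer is hypothetical. It shows why a stack of pre-norm layers emits un-normalized activations, so a trailing norm, if wanted, has to be added after the Encoder.

```python
import torch
from torch import nn


def sublayer(x, fn, norm, dropout, norm_first):
    """Hypothetical residual sub-layer in pre-norm or post-norm arrangement."""
    if norm_first:
        # Pre-norm: normalize the input, add the raw residual. Output is NOT normalized.
        return x + dropout(fn(norm(x)))
    # Post-norm: normalize the sum of input and output. Output IS normalized.
    return norm(x + dropout(fn(x)))


torch.manual_seed(0)
D = 16
fn, norm, drop = nn.Linear(D, D), nn.LayerNorm(D), nn.Dropout(0.0)
x = 5.0 * torch.randn(3, 7, D) + 2.0  # deliberately far from zero mean / unit variance

pre = sublayer(x, fn, norm, drop, norm_first=True)
post = sublayer(x, fn, norm, drop, norm_first=False)

# With pre-norm layers, apply a final norm yourself, after the encoder stack.
final_norm = nn.LayerNorm(D)
pre_normed = final_norm(pre)
```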

property context

Maximum context length permitted by the positional encodings.

property device

The device all weights, biases, activations, etc. reside on.

property dtype

The dtype of all weights, biases, activations, and parameters.

forward(src, attn_mask=None, src_mask=None, is_causal=True)

Forward pass through the transformer encoder with optional masking.

Parameters:
  • src (Tensor) – Input sequence(s) of token embeddings. Expected dimensions are (…, S, D), with S the sequence length and D the model dimension.

  • attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to. Defaults to None.

  • src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to. Defaults to None.

  • is_causal (bool, optional) – If set to True, inputs are masked with a causal S x S triangular matrix (as produced by generate_square_subsequent_mask) and both attn_mask and src_mask are ignored. Defaults to True.

Returns:

Transformed input, again with dimensions (…, S, D).

Return type:

Tensor

Important

Boolean attention masks are not accepted!

property has_pos_enc

Whether positional encodings are applied.

static merge_masks(attn_mask, src_mask, is_causal)

Utility method to merge attention and source masks if necessary.

Parameters:
  • attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to.

  • src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to.

  • is_causal (bool, optional) – If True, inputs are masked with a causal S x S triangular matrix and both attn_mask and src_mask are ignored.

Returns:

The merged masks, or None if none are provided or is_causal is True.

Return type:

Tensor or None
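One plausible way to merge the two float masks is sketched below. It assumes that src_mask acts as a key-padding mask (masking whole columns of the S x S attention weights) and that the two masks are simply added. The actual merge_masks may differ in details such as shape handling.

```python
import torch


def merge_masks(attn_mask, src_mask, is_causal):
    """Sketch (assumed logic): combine an (..., S, S) and an (..., S) float mask."""
    if is_causal:
        return None  # causal masking takes precedence; both masks are ignored
    if src_mask is not None:
        # (..., S) -> (..., 1, S): every query row sees the same masked key columns.
        src_mask = src_mask.unsqueeze(-2)
    if attn_mask is None:
        return src_mask  # may itself be None
    if src_mask is None:
        return attn_mask
    return attn_mask + src_mask  # -inf anywhere stays -inf; 0.0 + 0.0 stays 0.0


S = 4
attn_mask = torch.zeros(S, S)
attn_mask[0, 3] = float("-inf")                            # query 0 may not see key 3
src_mask = torch.tensor([[0.0, 0.0, float("-inf"), 0.0]])  # batch of 1: key 2 is padding
merged = merge_masks(attn_mask, src_mask, is_causal=False)
```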

property mod_dim

The model dimension.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

Encoder

reset_parameters()

Reset all learnable parameters in all components of the model.

class Compressor(model, attend, forward, pos_enc=None, bias=True, dropout=0.0, norm_first=True, norm_cls=<class 'torch.nn.modules.normalization.LayerNorm'>, *args, device='cpu', dtype=torch.float32, **kwargs)

Bases: Trafo

Context-aware length-compression wrapper around sequence models.

The Compressor uses self- and cross-attention to compress incoming sequences of embeddings in length by a fixed factor of 2. The compressed sequences are then passed through the wrapped model and inflated back to the original length using cross-attention, with U-Net-style residual connections to the compression stage.

Parameters:
  • model (Resettable) – The sequence model to wrap, typically an Encoder. It will be called with 3 arguments: the compressed sequences, a correspondingly compressed attention mask (or None), and a boolean flag indicating whether the attention mask is causal or not.

  • attend (Attn) – A suitably parameterized instance of a self-attention block, typically MultiheadedSelfAttention or GroupedQuerySelfAttention.

  • forward (Block) –

    PyTorch Module that

    • has a reset_parameters() method,

    • has a new() method to make fresh copies of itself,

    • processes tensors with dimensions (…, S, D),

    where S is the sequence length and D is the model dimension specified in the attention.

  • pos_enc (PosEnc, optional) –

    PyTorch Module that

    • has a reset_parameters() method,

    • has a context attribute specifying the maximum sequence length,

    • processes tensors with dimensions (…, S, D),

    where S is the sequence length and D is the model dimension specified in attend. If given, it will be called on the input tensor first thing. Typically, this would be an instance of Sinusoidal or Learnable positional encodings. Defaults to an instance of IdentityBlock, which does nothing.

  • bias (bool, optional) – Whether to use a bias in the cross-attention components. Defaults to True.

  • dropout (float, optional) – Apply this amount of dropout to the sum of embeddings and positional encodings as well as to the outputs of each sub-layer. Defaults to 0.

  • norm_first (bool, optional) – Whether to normalize inputs to attentions and feed-forwards or the sum of respective inputs and outputs. Defaults to True.

  • norm_cls (type, optional) – Which type of norm to use between (sub-)layers. Must be one of torch.nn.LayerNorm (the default) or torch.nn.RMSNorm.

  • *args – Arguments used to initialize an instance of norm_cls.

  • device (str or torch.device, optional) – Torch device to first create the compressor on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype to first create the compressor in. Defaults to torch.float.

  • **kwargs – Keyword arguments used to initialize an instance of norm_cls.

Important

If norm_first is True, the wrapped model will receive an un-normed input and may return one. If, however, norm_first is False, the wrapped model will receive a normed input and must return one!

property context

Maximum context length permitted by the positional encodings.

property device

The device all weights, biases, activations, etc. reside on.

property dtype

The dtype of all weights, biases, activations, and parameters.

forward(src, attn_mask=None, src_mask=None, is_causal=True)

Forward pass through the compressor with optional masking.

Parameters:
  • src (Tensor) – Input sequence(s) of embeddings. Expected dimensions are (…, S, D), with S the sequence length and D the model dimension.

  • attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to. Defaults to None.

  • src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to. Defaults to None.

  • is_causal (bool, optional) –

    If set to True, inputs are masked with a causal S x S triangular matrix (as produced by generate_square_subsequent_mask) and both attn_mask and src_mask are ignored. Defaults to True.

Returns:

Transformed input, again with dimensions (…, S, D).

Return type:

Tensor

Important

Boolean attention masks are not accepted!

property has_pos_enc

Whether positional encodings are applied.

static merge_masks(attn_mask, src_mask, is_causal)

Utility method to merge attention and source masks if necessary.

Parameters:
  • attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to.

  • src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to.

  • is_causal (bool, optional) – If True, inputs are masked with a causal S x S triangular matrix and both attn_mask and src_mask are ignored.

Returns:

The merged masks, or None if none are provided or is_causal is True.

Return type:

Tensor or None

property mod_dim

The model dimension.

property n_heads

The number of attention heads used.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

Compressor

reset_parameters()

Reset all learnable parameters in all components of the model.

Attention

class MultiheadedSelfAttention(mod_dim, n_heads=1, bias=False, dropout=0.1, pos_enc=None, device='cpu', dtype=torch.float32)

Bases: Attn

Multi-headed self attention with optional (rotary) positional encodings.

Parameters:
  • mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.

  • n_heads (int, optional) – The number of attention heads. Must integer divide mod_dim, and the result must be an even number. Defaults to 1.

  • bias (bool, optional) – Whether to add learnable bias vectors in the projections from input to query, key and value and the final out projection. Defaults to False.

  • dropout (float, optional) – Apply dropout to the attention weights with this probability during training. Defaults to 0.1.

  • pos_enc (PosEnc, optional) –

    PyTorch Module that

    • has a reset_parameters() method,

    • has a new() method to make fresh copies of itself,

    • has a context attribute specifying the maximum sequence length,

    • processes tensors with dimensions (…, n_heads, S, head_dim),

    where S is the sequence length, and head_dim is the mod_dim divided by n_heads. If given, it will be called on queries and keys. Typically, this would be an instance of Rotary positional encodings. Defaults to an instance of IdentityBlock, which does nothing.

  • device (str or torch.device, optional) – Torch device to compute self attention on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype to compute self attention in. Defaults to torch.float.

Raises:

ValueError – If n_heads does not integer divide mod_dim.

See also

Rotary, IdentityBlock
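The constraints on n_heads come down to integer arithmetic: head_dim = mod_dim // n_heads must be exact and, per the parameter description, even (presumably so that rotary encodings can rotate pairs of features). A minimal pure-Python sketch of such a check follows. The function head_dim_for is hypothetical.

```python
def head_dim_for(mod_dim: int, n_heads: int) -> int:
    """Hypothetical check mirroring the documented constraints on n_heads."""
    if n_heads < 1 or mod_dim % n_heads != 0:
        raise ValueError(f"n_heads={n_heads} does not integer divide mod_dim={mod_dim}")
    head_dim = mod_dim // n_heads
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim={head_dim} is not an even number")
    return head_dim


print(head_dim_for(512, 8))  # 64
```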

property context

Maximum context length of the positional encodings, if present.

property device

Device to compute self attention on.

property dtype

Dtype to compute self attention in.

forward(src, mask=None, is_causal=True)

Forward pass through multi-headed self attention.

Parameters:
  • src (Tensor) – Input sequence(s) of dimensions (…, S, mod_dim), with sequence length S.

  • mask (Tensor, optional) – Attention mask with a shape broadcastable to the shape of the attention weights (…, S, S). Two types of masks are supported: a boolean mask, where True indicates that the element should take part in attention, or a float mask of the same dtype as src, which is added to the product of queries and keys before the softmax. In the latter case, a value of 0.0 (leaving the attention weights unchanged) indicates that an element should take part in attention and a value of “-inf” (resulting in a zero attention weight) that it should not. Defaults to None.

  • is_causal (bool, optional) – If set to True, inputs are masked with an S x S lower triangular matrix and mask is ignored. Defaults to True.

Returns:

The output has the same shape as the input.

Return type:

Tensor

Important

Following the convention of scaled_dot_product_attention, the meaning of True and False (attend to and do not attend to, respectively) in boolean attention masks is exactly the opposite of their meaning in torch.nn.MultiheadAttention. Therefore, to stay compatible, use float masks!
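For orientation, here is the usual multi-head self-attention computation, sketched with standard PyTorch calls. It assumes a fused input projection and omits positional encodings (a pos_enc such as Rotary would act on q and k right after the head split). The function mhsa and its arguments are illustrative; the class's internals may differ.

```python
import torch
import torch.nn.functional as F
from torch import nn


def mhsa(src, qkv_proj, out_proj, n_heads, mask=None, is_causal=True, dropout=0.0):
    """Illustrative multi-head self-attention over (..., S, D) inputs."""
    *lead, S, D = src.shape
    head_dim = D // n_heads
    q, k, v = qkv_proj(src).chunk(3, dim=-1)  # each (..., S, D)
    # (..., S, D) -> (..., S, n_heads, head_dim) -> (..., n_heads, S, head_dim)
    q, k, v = (t.reshape(*lead, S, n_heads, head_dim).transpose(-3, -2) for t in (q, k, v))
    # A positional encoding acting on (..., n_heads, S, head_dim) would be applied to q and k here.
    attn = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None if is_causal else mask,  # mask is ignored when is_causal=True
        dropout_p=dropout,
        is_causal=is_causal,
    )
    # (..., n_heads, S, head_dim) -> (..., S, D), then the output projection
    return out_proj(attn.transpose(-3, -2).reshape(*lead, S, D))


D, H = 32, 4
qkv_proj = nn.Linear(D, 3 * D, bias=False)
out_proj = nn.Linear(D, D, bias=False)
x = torch.randn(2, 5, D)
y = mhsa(x, qkv_proj, out_proj, n_heads=H)
```

Changing a later token leaves the outputs at earlier positions untouched, which is a quick way to check that causal masking is in effect.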

property has_pos_enc

Whether a pos_enc module was provided at instantiation or not.

property head_dim

The dimension of each attention head.

property mod_dim

The model dimension.

property n_heads

The number of attention heads used.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

MultiheadedSelfAttention

reset_parameters()

Reset the internal parameters of the projections and pos_enc.

class GroupedQuerySelfAttention(mod_dim, n_heads=1, q_factor=1, bias=False, dropout=0.1, pos_enc=None, device='cpu', dtype=torch.float32)

Bases: Attn

Grouped-query attention with optional (rotary) positional encoding.

Parameters:
  • mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.

  • n_heads (int, optional) – The number of attention heads. Must integer divide mod_dim, and the result must be an even number. Defaults to 1.

  • q_factor (int, optional) – Reduce the number of attention heads for keys and values by this factor compared to the n_heads used for queries. Must integer divide n_heads. Realizes standard multi-head attention (MHA) with q_factor=1, multi-query attention (MQA) with q_factor=n_heads, and grouped-query attention (GQA) otherwise. Defaults to 1 (i.e., standard MHA).

  • bias (bool, optional) – Whether to add learnable bias vectors in the projections from input to query, key and value and the final out projection. Defaults to False.

  • dropout (float, optional) – Apply dropout to the attention weights with this probability during training. Defaults to 0.1.

  • pos_enc (PosEnc, optional) –

    PyTorch Module that

    • has a reset_parameters() method,

    • has a new() method to make fresh copies of itself,

    • has a context attribute specifying the maximum sequence length,

    • processes tensors with dimensions (…, n_heads, S, head_dim),

    where S is the sequence length, and head_dim is the mod_dim divided by n_heads. If given, it will be called on queries and keys. Typically, this would be an instance of Rotary positional encodings. Defaults to an instance of IdentityBlock, which does nothing.

  • device (str or torch.device, optional) – Torch device to compute self attention on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype to compute self attention in. Defaults to torch.float.

Raises:

ValueError – If n_heads does not integer divide mod_dim, or if q_factor does not integer divide n_heads.

See also

Rotary, IdentityBlock
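Given q_factor, the number of key/value heads and the total key/value dimension (the n_kv_heads and kv_dim properties) presumably follow from integer division, as in this pure-Python sketch. The function gqa_dims is hypothetical.

```python
def gqa_dims(mod_dim: int, n_heads: int, q_factor: int) -> dict:
    """Hypothetical helper deriving grouped-query attention sizes."""
    if mod_dim % n_heads != 0:
        raise ValueError("n_heads must integer divide mod_dim")
    if n_heads % q_factor != 0:
        raise ValueError("q_factor must integer divide n_heads")
    head_dim = mod_dim // n_heads
    n_kv_heads = n_heads // q_factor  # each K/V head serves q_factor query heads
    return {"head_dim": head_dim, "n_kv_heads": n_kv_heads, "kv_dim": n_kv_heads * head_dim}


print(gqa_dims(512, 8, 1))  # {'head_dim': 64, 'n_kv_heads': 8, 'kv_dim': 512}  (MHA)
print(gqa_dims(512, 8, 4))  # {'head_dim': 64, 'n_kv_heads': 2, 'kv_dim': 128}  (GQA)
print(gqa_dims(512, 8, 8))  # {'head_dim': 64, 'n_kv_heads': 1, 'kv_dim': 64}   (MQA)
```

The savings show up in kv_dim: the key and value projections (and any key/value cache) shrink by q_factor, while the query side keeps all n_heads.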

property context

Maximum context length of the positional encodings, if present.

property device

Device to compute self attention on.

property dtype

Dtype to compute self attention in.

forward(src, mask=None, is_causal=True)

Forward pass through grouped query attention.

Parameters:
  • src (Tensor) – Input sequence(s) of dimensions (…, S, mod_dim), with sequence length S.

  • mask (Tensor, optional) – Attention mask with a shape broadcastable to the shape of the attention weights (…, S, S). Two types of masks are supported: a boolean mask, where True indicates that the element should take part in attention, or a float mask of the same dtype as src, which is added to the product of queries and keys before the softmax. In the latter case, a value of 0.0 (leaving the attention weights unchanged) indicates that an element should take part in attention and a value of “-inf” (resulting in a zero attention weight) that it should not. Defaults to None.

  • is_causal (bool, optional) – If set to True, inputs are masked with an S x S lower triangular matrix and mask is ignored. Defaults to True.

Returns:

The output has the same shape as the input.

Return type:

Tensor

Important

Following the convention of scaled_dot_product_attention, the meaning of True and False (attend to and do not attend to, respectively) in boolean attention masks is exactly the opposite of their meaning in torch.nn.MultiheadAttention. Therefore, to stay compatible, use float masks!
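In grouped-query attention, each of the n_kv_heads key/value heads is shared by q_factor consecutive query heads. Below is a minimal sketch of that sharing, using torch.repeat_interleave followed by standard scaled-dot-product attention. Projections and positional encodings are omitted, and the function grouped_query_attention is illustrative, not this class's code.

```python
import torch
import torch.nn.functional as F


def grouped_query_attention(q, k, v, is_causal=True):
    """q: (..., n_heads, S, head_dim); k, v: (..., n_kv_heads, S, head_dim)."""
    n_heads, n_kv_heads = q.shape[-3], k.shape[-3]
    q_factor = n_heads // n_kv_heads
    # Query head h uses key/value head h // q_factor.
    k = k.repeat_interleave(q_factor, dim=-3)
    v = v.repeat_interleave(q_factor, dim=-3)
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)


B, H, H_kv, S, hd = 2, 8, 2, 6, 16
q = torch.randn(B, H, S, hd)
k = torch.randn(B, H_kv, S, hd)
v = torch.randn(B, H_kv, S, hd)
out = grouped_query_attention(q, k, v)
```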

property has_pos_enc

Whether a pos_enc module was provided at instantiation or not.

property head_dim

The dimension of each attention head.

property kv_dim

Total dimension of keys (or values), i.e., n_kv_heads * head_dim.

property mod_dim

The model dimension.

property n_heads

The number of attention heads used.

property n_kv_heads

The number of key/value heads.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

GroupedQuerySelfAttention

reset_parameters()

Reset the internal parameters of the projections and pos_enc.

Positional encodings

class Sinusoidal(mod_dim, context, *_, device='cpu', dtype=torch.float32, **__)

Bases: PosEnc

Sinusoidal positional encodings for transformer-based sequence models.

Parameters:
  • mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.

  • context (int) – The maximum sequence length that can be processed. Inputs are expected to not exceed this size in their next-to-last dimension.

  • device (str or torch.device, optional) – Torch device to first create the sinusoidal positional encodings on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype to first create the sinusoidal positional encodings in. Defaults to torch.float.

property context

The maximum sequence length.

property device

Device that the sinusoidal positional encodings reside on.

property dtype

Dtype of the sinusoidal positional encodings.

forward(src, offset=0)

Add sinusoidal positional encodings to a sequence of embeddings.

Parameters:
  • src (Tensor) – Input sequence(s). Must be of dimensions (…, S, mod_dim), where the sequence length S must not exceed context.

  • offset (int, optional) – Offset to add to the index of the positional encodings. Defaults to 0.

Returns:

The input sequence(s) with sinusoidal positional encodings added.

Return type:

Tensor
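For reference, the classic sinusoidal scheme of the original Transformer uses PE[p, 2i] = sin(p / 10000^(2i/D)) and PE[p, 2i+1] = cos(p / 10000^(2i/D)). Below is a minimal sketch that precomputes this table up to context and adds a slice of it, shifted by offset, to the input. Interleaving sine and cosine, the base of 10000, and the bounds check on offset + S are assumptions about this class. The functions sinusoidal_table and add_sinusoidal are hypothetical.

```python
import math

import torch


def sinusoidal_table(context: int, mod_dim: int) -> torch.Tensor:
    """(context, mod_dim) table with sin on even and cos on odd feature indices."""
    position = torch.arange(context, dtype=torch.float32).unsqueeze(1)  # (context, 1)
    # 1 / 10000^(2i / mod_dim) for i = 0 .. mod_dim/2 - 1
    inv_freq = torch.exp(
        torch.arange(0, mod_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / mod_dim)
    )
    table = torch.zeros(context, mod_dim)
    table[:, 0::2] = torch.sin(position * inv_freq)
    table[:, 1::2] = torch.cos(position * inv_freq)
    return table


def add_sinusoidal(src: torch.Tensor, table: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Add rows offset .. offset+S-1 of the table to (..., S, mod_dim) inputs."""
    S = src.shape[-2]
    if offset + S > table.shape[0]:  # assumed check against the context
        raise ValueError("sequence (plus offset) exceeds the context")
    return src + table[offset:offset + S]  # broadcasts over leading dims


table = sinusoidal_table(context=128, mod_dim=16)
x = torch.zeros(2, 10, 16)
y = add_sinusoidal(x, table, offset=5)
```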

property mod_dim

The model dimension.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

Sinusoidal

reset_parameters()

Does nothing because there are no internal parameters to reset.

class Rotary(mod_dim, context, n_heads, device='cpu', dtype=torch.float32, **_)

Bases: PosEnc

Rotary positional encodings for multi-head attention in sequence models.

Parameters:
  • mod_dim (int) – The model dimension. Each vector in the original sequence is expected to be of that dimension.

  • context (int) – The maximum sequence length that can be processed. Inputs are expected to not exceed this size in their next-to-last dimension.

  • n_heads (int) – The number of attention heads. Must integer divide mod_dim, and the result must be an even number.

  • device (str or torch.device, optional) – Torch device to first create the rotary positional encodings on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype to first create the rotary positional encodings in. Defaults to torch.float.

Raises:

ValueError – If n_heads does not integer divide mod_dim or if the result is not an even number.

property context

The maximum sequence length.

property device

Device that the rotary positional encodings reside on.

property dtype

Dtype of the rotary positional encodings.

forward(src, offset=0)

Apply rotary positional encodings across all heads of the input.

Parameters:
  • src (Tensor) – Input sequence(s) for all heads. Must be of dimensions (…, n_heads, S, head_dim), where the sequence length S must not exceed context and head_dim is the mod_dim divided by n_heads.

  • offset (int, optional) – Unused. Only for purposes of API compatibility.

Returns:

The input sequence(s) with rotary positional encodings applied to all heads.

Return type:

Tensor
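Rotary encodings rotate each pair of features of a query or key vector by an angle proportional to its position. As a result, the dot product between a rotated query at position m and a rotated key at position n depends only on m - n. Below is a minimal sketch acting on (…, n_heads, S, head_dim) tensors. Pairing feature i with feature i + head_dim/2 (the "rotate-half" layout) and the base of 10000 are assumptions, and implementations differ in these details. All function names are hypothetical.

```python
import torch


def rotary_cache(context: int, head_dim: int, base: float = 10000.0):
    """Precompute cos/sin tables of shape (context, head_dim)."""
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    angles = torch.outer(torch.arange(context, dtype=torch.float32), inv_freq)  # (context, head_dim/2)
    angles = torch.cat([angles, angles], dim=-1)  # same angle for features i and i + head_dim/2
    return angles.cos(), angles.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """(x1, x2) -> (-x2, x1) on the two halves of the last dimension."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(src: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """src: (..., n_heads, S, head_dim) with S <= context."""
    S = src.shape[-2]
    # 2D rotation of each (x_i, x_{i + head_dim/2}) pair by its position-dependent angle.
    return src * cos[:S] + rotate_half(src) * sin[:S]


n_heads, S, head_dim = 4, 12, 8
cos, sin = rotary_cache(context=64, head_dim=head_dim)
q = torch.randn(2, n_heads, S, head_dim)
q_rot = apply_rotary(q, cos, sin)
```

Since only rotations are applied, vector norms are preserved and there are no learnable parameters, consistent with reset_parameters() doing nothing.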

property head_dim

The dimension of each attention head.

property mod_dim

The model dimension.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

Rotary

reset_parameters()

Does nothing because there are no internal parameters to reset.

class Learnable(mod_dim, context, *_, device='cpu', dtype=torch.float32, **__)

Bases: PosEnc

Learnable positional encodings for transformer-based sequence models.

Parameters:
  • mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.

  • context (int) – The maximum sequence length that can be processed. Inputs are expected to not exceed this size in their next-to-last dimension.

  • device (str or torch.device, optional) – Torch device to create the learnable positional encodings on. Defaults to “cpu”.

  • dtype (torch.dtype, optional) – Torch dtype of the learnable positional encodings. Defaults to torch.float.

Warning

Make sure that the context reflects the maximum length of the sequences that your model sees at training time. In contrast to other types of positional encodings, which can reasonably be expected to generalize well beyond that during inference, positions that have never been encountered during training cannot be encoded at all with Learnable.

See also

Sinusoidal, Rotary

property context

The maximum sequence length.

property device

Device that the learnable positional encodings reside on.

property dtype

Dtype of the learnable positional encodings.

forward(src, offset=0)

Add learnable positional encodings to a sequence of embeddings.

Parameters:
  • src (Tensor) – Input sequence(s). Must be of dimensions (…, S, mod_dim), where the sequence length S must not exceed context.

  • offset (int, optional) – Offset to add to the index of the positional encodings. Defaults to 0.

Returns:

The input sequence(s) with positional encodings added.

Return type:

Tensor
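Learnable positional encodings amount to a trainable (context, mod_dim) table whose rows are added by position. Below is a minimal sketch. The class LearnablePositions and its small-normal initialization are assumptions and do not reflect this class's actual initialization.

```python
import torch
from torch import nn


class LearnablePositions(nn.Module):
    """Hypothetical trainable positional table, added to (..., S, mod_dim) inputs."""

    def __init__(self, mod_dim: int, context: int):
        super().__init__()
        self.table = nn.Parameter(torch.empty(context, mod_dim))
        self.reset_parameters()

    @property
    def context(self) -> int:
        return self.table.shape[0]

    def reset_parameters(self) -> None:
        nn.init.normal_(self.table, std=0.02)  # assumed init scheme

    def forward(self, src: torch.Tensor, offset: int = 0) -> torch.Tensor:
        S = src.shape[-2]
        if offset + S > self.context:
            # There is simply no row for positions beyond the table.
            raise ValueError(f"positions up to {offset + S} exceed context={self.context}")
        return src + self.table[offset:offset + S]


pos = LearnablePositions(mod_dim=16, context=32)
x = torch.zeros(3, 10, 16)
y = pos(x, offset=4)
```

Only the rows that were actually used receive gradients, which is why positions never seen during training stay at their initial values.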

property mod_dim

The model dimension.

new()

Return a fresh, new instance with exactly the same parameters.

Returns:

A fresh, new instance of itself.

Return type:

Learnable

reset_parameters()

Re-initialize the learnable positional encodings.

Base classes

class PosEnc(*_, **__)

Bases: Block

abstract property context

The maximum sequence length.

class Trafo(*_, **__)

Bases: PosEnc

abstract property has_pos_enc

Whether positional encodings are applied.

class Attn(*_, **__)

Bases: Trafo

abstract property n_heads

The number of attention heads used.
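These base classes only fix interfaces. Below is a minimal sketch of how such abstract properties can be declared on top of torch.nn.Module with Python's abc module, together with a do-nothing positional encoding in the spirit of IdentityBlock. Block is a hypothetical stand-in whose abstract methods (reset_parameters() and new()) are inferred from the requirements listed for pos_enc and feed_forward. NoPosEnc is likewise hypothetical.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class Block(nn.Module, ABC):
    """Hypothetical stand-in for the library's Block base class."""

    @abstractmethod
    def reset_parameters(self) -> None: ...

    @abstractmethod
    def new(self) -> "Block": ...


class PosEnc(Block):
    @property
    @abstractmethod
    def context(self) -> int:
        """The maximum sequence length."""


class NoPosEnc(PosEnc):
    """Hypothetical identity positional encoding with a nominal context."""

    def __init__(self, context: int = 4096):
        super().__init__()
        self._context = context

    @property
    def context(self) -> int:
        return self._context

    def forward(self, src: torch.Tensor, offset: int = 0) -> torch.Tensor:
        return src  # leaves the input untouched

    def reset_parameters(self) -> None:
        pass  # nothing to reset

    def new(self) -> "NoPosEnc":
        return self.__class__(self._context)


enc = NoPosEnc(context=256)
x = torch.randn(2, 5, 8)
```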