transformer
Custom components for modern transformer architectures.
These include, but are not limited to, sinusoidal and rotary (RoPE) positional encodings, flexible grouped-query self-attention, and a full-blown encoder-only transformer built from these components.
- class EncoderLayer(attention, feed_forward, pos_enc=None, dropout=0.1, norm_first=True, norm_cls=<class 'torch.nn.modules.normalization.LayerNorm'>, *args, device='cpu', dtype=torch.float32, **kwargs)
Bases: Trafo
Encoder layer (i.e., self-attention only) to use in a transformer.
- Parameters:
attention (Attn) – A suitably parameterized instance of a self-attention block, typically MultiheadedSelfAttention or GroupedQuerySelfAttention.
feed_forward (Block) – PyTorch Module that has a reset_parameters() method, has a new() method to make fresh copies of itself, and processes tensors with dimensions (…, S, D), where S is the sequence length and D is the model dimension specified in the attention.
pos_enc (PosEnc, optional) – PyTorch Module that has a reset_parameters() method, has a new() method to make fresh copies of itself, has a context attribute specifying the maximum sequence length, and processes tensors with dimensions (…, S, D), where S is the sequence length and D is the model dimension specified in the attention. If given, it will be called on the input tensor first thing. Typically, this would be an instance of Sinusoidal or Learnable positional encodings. Defaults to an instance of IdentityBlock, which does nothing.
dropout (float, optional) – Fraction of dropout to apply after self-attention and feed-forward. Defaults to 0.1.
norm_first (bool, optional) – Whether to normalize the inputs to attention and feed-forward or the sums of their respective inputs and outputs. Defaults to True.
norm_cls (type, optional) – Which type of norm to use between (sub-)layers. Must be one of torch.nn.LayerNorm (the default) or torch.nn.RMSNorm.
*args – Arguments used to initialize an instance of norm_cls.
device (str or torch.device, optional) – Torch device to first create the encoder layer on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype to first create the layer in. Defaults to torch.float.
**kwargs – Keyword arguments used to initialize an instance of norm_cls.
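The norm_first switch selects between the pre-norm and post-norm arrangement of each residual sub-layer. Below is a minimal sketch of both variants in plain PyTorch. The helper residual_sublayer is hypothetical and not part of this module, and the exact place where the layer applies dropout is an assumption based on the dropout description above.

```python
import torch
from torch import nn


def residual_sublayer(x, sublayer, norm, dropout, norm_first=True):
    """Hypothetical helper: one residual sub-layer in pre- or post-norm order."""
    if norm_first:
        # Pre-norm: normalize the sub-layer input and add the raw input back.
        return x + dropout(sublayer(norm(x)))
    # Post-norm: normalize the sum of the sub-layer input and its output.
    return norm(x + dropout(sublayer(x)))


torch.manual_seed(0)
mod_dim = 16
x = torch.randn(2, 5, mod_dim)  # (batch, S, D)
feed_forward = nn.Sequential(
    nn.Linear(mod_dim, 4 * mod_dim), nn.GELU(), nn.Linear(4 * mod_dim, mod_dim)
)
norm = nn.LayerNorm(mod_dim)
dropout = nn.Dropout(0.1).eval()  # eval mode: dropout is a no-op here

pre = residual_sublayer(x, feed_forward, norm, dropout, norm_first=True)
post = residual_sublayer(x, feed_forward, norm, dropout, norm_first=False)
```

With post-norm, every output token is normalized. With pre-norm, the residual stream itself is never normalized. This is why a trailing norm after a pre-norm stack has to be added explicitly.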
- property context
Maximum context length given by the positional encodings.
- property device
The device all weights, biases, activations, etc. reside on.
- property dtype
The dtype of all weights, biases, activations, and parameters.
- forward(src, mask=None, is_causal=True)
Forward pass of one encoder layer (i.e., with self-attention only).
- Parameters:
src (Tensor) – Input sequence(s) of dimensions (…, S, D), with sequence length S and model dimension D.
mask (Tensor, optional) – Attention mask with a shape broadcastable to the shape of the attention weights (…, S, S). Two types of masks are supported: a boolean mask, where a value of True indicates that the element should take part in attention, or a float mask of the same dtype as src that is added to the product of queries and keys before taking the softmax. In the latter case, a value of 0.0 (resulting in unchanged attention weights) indicates that an element should take part in the attention and a value of “-inf” (resulting in a zero attention weight) that it should not. Defaults to None.
is_causal (bool, optional) – If set to True, inputs are masked with an S x S lower-triangular matrix and mask is ignored. Defaults to True.
- Returns:
The output has the same shape as the input.
- Return type:
Tensor
Important
Following the convention of scaled_dot_product_attention, True and False in boolean attention masks mean “attend to” and “do not attend to”, respectively. This is exactly the opposite of their meaning in PyTorch's Transformer. Therefore, to stay compatible, use float masks!
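The two boolean conventions can be compared directly with torch.nn.functional.scaled_dot_product_attention. The sketch below uses toy shapes and random data. It expresses the same restriction once as a boolean mask in that function's convention and once as a float mask, and checks that both agree with simply dropping the masked keys. The tensor named blocked only illustrates the inverted, Transformer-style convention.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, head_dim = 4, 8
# (batch, heads, S, head_dim)
q, k, v = (torch.randn(1, 2, S, head_dim) for _ in range(3))

# torch.nn.Transformer convention: True means "must NOT attend".
# Here, every query is barred from the last two keys.
blocked = torch.tensor([False, False, True, True]).expand(S, S)

# scaled_dot_product_attention convention: True means "attend".
allowed = ~blocked

# Float mask, unambiguous in both worlds: 0.0 to attend, -inf to ignore.
float_mask = torch.zeros(S, S).masked_fill(blocked, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
```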
- property has_pos_enc
Whether positional encodings are applied.
- property mod_dim
The model dimension.
- class Encoder(layer, n_layers=1, pos_enc=None, dropout=0.0, device='cpu', dtype=torch.float32)
Bases: Trafo
Flexible transformer encoder.
- Parameters:
layer (EncoderLayer) – A suitably parameterized instance of EncoderLayer.
n_layers (int, optional) – How often the layer is repeated in the transformer stack. Must be at least 1, the default.
pos_enc (PosEnc, optional) – PyTorch Module that has a reset_parameters() method, has a context attribute specifying the maximum sequence length, and processes tensors with dimensions (…, S, D), where S is the sequence length and D is the model dimension specified in the layer. If given, it will be called on the input tensor first thing. Typically, this would be an instance of Sinusoidal or Learnable positional encodings. Defaults to an instance of IdentityBlock, which does nothing.
dropout (float, optional) – Apply dropout to the sum of token embeddings and positional encodings with this probability during training. Defaults to 0.
device (str or torch.device, optional) – Torch device to first create the transformer on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype to first create the transformer encoder stack in. Defaults to torch.float.
- Raises:
ValueError – If n_layers is less than 1.
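This section does not document how the Encoder builds its stack. If it needs n_layers independent copies of layer (the new() methods of the components suggest as much), a minimal stand-in built from copy.deepcopy and reset_parameters() could look like this. build_stack is hypothetical, and nn.Linear only stands in for an EncoderLayer.

```python
import copy

import torch
from torch import nn


def build_stack(layer: nn.Module, n_layers: int = 1) -> nn.ModuleList:
    """Hypothetical sketch: repeat a layer n_layers times without weight sharing."""
    if n_layers < 1:
        raise ValueError(f"n_layers must be at least 1, got {n_layers}!")
    layers = nn.ModuleList()
    for _ in range(n_layers):
        fresh = copy.deepcopy(layer)  # independent parameters for every copy
        fresh.reset_parameters()      # re-draw the initial weights
        layers.append(fresh)
    return layers


torch.manual_seed(0)
stack = build_stack(nn.Linear(8, 8), n_layers=3)
```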
Note
If the layer sets norm_first to True, no norm is applied to the final output of the Encoder. If a trailing norm is desired, it should be applied externally, after this module.
- property context
Maximum context length permitted by the positional encodings.
- property device
The device all weights, biases, activations, etc. reside on.
- property dtype
The dtype of all weights, biases, activations, and parameters.
- forward(src, attn_mask=None, src_mask=None, is_causal=True)
Forward pass through the transformer encoder with optional masking.
- Parameters:
src (Tensor) – Input sequence(s) of token embeddings. Expected dimensions are (…, S, D), with S the sequence length and D the model dimension.
attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to. Defaults to None.
src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src without its last dimension, i.e., (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to. Defaults to None.
is_causal (bool, optional) – If set to True, inputs are masked with a causal S x S triangular matrix (as produced by generate_square_subsequent_mask) and both attn_mask and src_mask are ignored. Defaults to True.
- Returns:
The transformed input, again with dimensions (…, S, D).
- Return type:
Tensor
Important
Boolean attention masks are not accepted!
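Because only float masks are accepted, boolean information has to be converted first. The sketch below builds a causal mask with torch.nn.Transformer.generate_square_subsequent_mask (a static method in recent PyTorch versions) and a padding mask of shape (batch, S) from made-up sequence lengths. Our reading of the parameter descriptions above is that such masks would be passed as attn_mask and src_mask together with is_causal=False, because is_causal=True makes forward ignore both.

```python
import torch
from torch import nn

S = 5
# Causal float mask: 0.0 on and below the diagonal, -inf above it.
causal = nn.Transformer.generate_square_subsequent_mask(S)

# Padding mask for a batch of two sequences whose true lengths are 5 and 3
# (made-up example data). Shape (batch, S): 0.0 for tokens, -inf for padding.
lengths = torch.tensor([5, 3])
is_pad = torch.arange(S).unsqueeze(0) >= lengths.unsqueeze(1)
src_mask = torch.zeros(len(lengths), S).masked_fill(is_pad, float("-inf"))
```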
- property has_pos_enc
Whether positional encodings are applied.
- static merge_masks(attn_mask, src_mask, is_causal)
Utility method to merge attention and source masks if necessary.
- Parameters:
attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to.
src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to.
is_causal (bool, optional) – If True, inputs are masked with a causal S x S triangular matrix and both attn_mask and src_mask are ignored.
- Returns:
The merged masks, or None if none are provided or is_causal is True.
- Return type:
Tensor or None
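The documentation does not spell out how the two masks are combined. One plausible reading, sketched below under that assumption, is that src_mask masks keys (columns) and is broadcast over all query rows before being added to attn_mask. merge_masks_sketch is a hypothetical stand-in, not the library's code, and any extra head dimension is assumed to be handled downstream.

```python
import torch


def merge_masks_sketch(attn_mask, src_mask, is_causal):
    """Hypothetical re-implementation for illustration only.

    attn_mask: float mask broadcastable to (..., S, S), or None.
    src_mask:  float mask broadcastable to (..., S), or None; masks keys.
    """
    if is_causal:
        return None  # the causal flag takes precedence over explicit masks
    if src_mask is not None:
        # Broadcast the per-key mask over all query positions: (..., 1, S).
        src_mask = src_mask.unsqueeze(-2)
    if attn_mask is None:
        return src_mask
    if src_mask is None:
        return attn_mask
    return attn_mask + src_mask  # -inf in either mask stays -inf
```

Note that a query row that ends up entirely at -inf yields NaN after the softmax, so fully padded sequences need care.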
- property mod_dim
The model dimension.
- class Compressor(model, attend, forward, pos_enc=None, bias=True, dropout=0.0, norm_first=True, norm_cls=<class 'torch.nn.modules.normalization.LayerNorm'>, *args, device='cpu', dtype=torch.float32, **kwargs)
Bases: Trafo
Context-aware length-compression wrapper around sequence models.
The Compressor uses self- and cross-attention to compress incoming sequences of embeddings in length by a fixed factor of 2. The compressed sequences are then sent through the wrapped model and inflated again to the original length using cross-attention with U-net style residual connections to the compression stage.
- Parameters:
model (Resettable) – The sequence model to wrap, typically an Encoder. It will be called with 3 arguments: the compressed sequences, a correspondingly compressed attention mask (or None), and a boolean flag indicating whether the attention mask is causal or not.
attend (Attn) – A suitably parameterized instance of a self-attention block, typically MultiheadedSelfAttention or GroupedQuerySelfAttention.
forward (Block) – PyTorch Module that has a reset_parameters() method, has a new() method to make fresh copies of itself, and processes tensors with dimensions (…, S, D), where S is the sequence length and D is the model dimension specified in the attention.
pos_enc (PosEnc, optional) – PyTorch Module that has a reset_parameters() method, has a context attribute specifying the maximum sequence length, and processes tensors with dimensions (…, S, D), where S is the sequence length and D is the model dimension specified in attend. If given, it will be called on the input tensor first thing. Typically, this would be an instance of Sinusoidal or Learnable positional encodings. Defaults to an instance of IdentityBlock, which does nothing.
bias (bool, optional) – Whether to use a bias in the cross-attention components. Defaults to True.
dropout (float, optional) – Apply this amount of dropout to the sum of embeddings and positional encodings as well as to the outputs of each sub-layer. Defaults to 0.
norm_first (bool, optional) – Whether to normalize the inputs to attentions and feed-forwards or the sums of their respective inputs and outputs. Defaults to True.
norm_cls (type, optional) – Which type of norm to use between (sub-)layers. Must be one of torch.nn.LayerNorm (the default) or torch.nn.RMSNorm.
*args – Arguments used to initialize an instance of norm_cls.
device (str or torch.device, optional) – Torch device to first create the compressor on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype to first create the compressor in. Defaults to torch.float.
**kwargs – Keyword arguments used to initialize an instance of norm_cls.
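To make the shape bookkeeping of the Compressor description concrete, here is a toy sketch built from torch.nn.MultiheadAttention. This section does not document how the real Compressor forms its compressed queries or how it handles masks, norms, dropout and positional encodings. The pair-averaged queries below are therefore an assumption, and ToyCompressor is hypothetical.

```python
import torch
from torch import nn


class ToyCompressor(nn.Module):
    """Hypothetical sketch of 2x length compression around an inner model.

    Not the library's Compressor: the compressed queries are averages of
    neighbouring token pairs, and masks, norms, and dropout are omitted.
    """

    def __init__(self, model: nn.Module, mod_dim: int, n_heads: int = 1):
        super().__init__()
        self.model = model
        self.compress = nn.MultiheadAttention(mod_dim, n_heads, batch_first=True)
        self.inflate = nn.MultiheadAttention(mod_dim, n_heads, batch_first=True)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        batch, seq_len, mod_dim = src.shape
        if seq_len % 2:
            raise ValueError("This toy version needs an even sequence length!")
        # Queries for the compressed sequence: mean over pairs of tokens.
        pooled = src.reshape(batch, seq_len // 2, 2, mod_dim).mean(dim=2)
        # Cross-attention from S/2 queries to all S tokens halves the length.
        compressed, _ = self.compress(pooled, src, src, need_weights=False)
        processed = self.model(compressed)
        # Cross-attention from the S original positions to the S/2 outputs
        # restores the length; adding src is the U-net style skip connection.
        inflated, _ = self.inflate(src, processed, processed, need_weights=False)
        return src + inflated
```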
Important
If norm_first is True, the wrapped model will receive an un-normed input and may return one. If, however, norm_first is False, then the wrapped model will receive a normed input and must return one!
- property context
Maximum context length permitted by the positional encodings.
- property device
The device all weights, biases, activations, etc. reside on.
- property dtype
The dtype of all weights, biases, activations, and parameters.
- forward(src, attn_mask=None, src_mask=None, is_causal=True)
Forward pass through the compressor with optional masking.
- Parameters:
src (Tensor) – Input sequence(s) of embeddings. Expected dimensions are (…, S, D), with S the sequence length and D the model dimension.
attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to. Defaults to None.
src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src without its last dimension, i.e., (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to. Defaults to None.
is_causal (bool, optional) – If set to True, inputs are masked with a causal S x S triangular matrix (as produced by generate_square_subsequent_mask) and both attn_mask and src_mask are ignored. Defaults to True.
- Returns:
The transformed input, again with dimensions (…, S, D).
- Return type:
Tensor
Important
Boolean attention masks are not accepted!
- property has_pos_enc
Whether positional encodings are applied.
- static merge_masks(attn_mask, src_mask, is_causal)
Utility method to merge attention and source masks if necessary.
- Parameters:
attn_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of the attention weights (…, S, S) to be added to the product of queries and keys, before taking the softmax. A value of 0.0 (resulting in unchanged attention weights) indicates that an element should be attended to and a value of “-inf” (resulting in a zero attention weight) that it should not be attended to.
src_mask (Tensor, optional) – Floating-point attention mask with a shape broadcastable to the shape of src (…, S). A value of 0.0 indicates that an element should be attended to and a value of “-inf” that it should not be attended to.
is_causal (bool, optional) – If True, inputs are masked with a causal S x S triangular matrix and both attn_mask and src_mask are ignored.
- Returns:
The merged masks, or None if none are provided or is_causal is True.
- Return type:
Tensor or None
- property mod_dim
The model dimension.
- property n_heads
The number of attention heads used.
Attention
- class MultiheadedSelfAttention(mod_dim, n_heads=1, bias=False, dropout=0.1, pos_enc=None, device='cpu', dtype=torch.float32)
Bases: Attn
Multi-headed self-attention with optional (rotary) positional encodings.
- Parameters:
mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.
n_heads (int, optional) – The number of attention heads. Must integer divide mod_dim, and the result must still be an even number. Defaults to 1.
bias (bool, optional) – Whether to add learnable bias vectors in the projections from input to query, key and value and the final out projection. Defaults to False.
dropout (float, optional) – Apply dropout to the attention weights with this probability during training. Defaults to 0.1.
pos_enc (PosEnc, optional) – PyTorch Module that has a reset_parameters() method, has a new() method to make fresh copies of itself, has a context attribute specifying the maximum sequence length, and processes tensors with dimensions (…, n_heads, S, head_dim), where S is the sequence length and head_dim is mod_dim divided by n_heads. If given, it will be called on queries and keys. Typically, this would be an instance of Rotary positional encodings. Defaults to an instance of IdentityBlock, which does nothing.
device (str or torch.device, optional) – Torch device to compute self-attention on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype to compute self-attention in. Defaults to torch.float.
- Raises:
ValueError – If n_heads does not integer divide mod_dim.
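The following is a minimal sketch of what multi-headed self-attention does with these parameters, written directly against torch.nn.functional.scaled_dot_product_attention. SketchSelfAttention is hypothetical. It assumes a fused query/key/value projection and omits positional encodings; a comment marks where pos_enc would act on queries and keys. The even head_dim requirement is not checked here. It matters for rotary encodings, which rotate pairs of features.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SketchSelfAttention(nn.Module):
    """Hypothetical minimal multi-headed self-attention, for illustration."""

    def __init__(self, mod_dim: int, n_heads: int = 1, bias: bool = False, dropout: float = 0.1):
        super().__init__()
        if mod_dim % n_heads:
            raise ValueError(f"n_heads={n_heads} does not integer divide mod_dim={mod_dim}!")
        self.n_heads = n_heads
        self.head_dim = mod_dim // n_heads
        self.dropout = dropout
        self.qkv = nn.Linear(mod_dim, 3 * mod_dim, bias=bias)  # fused projection
        self.out = nn.Linear(mod_dim, mod_dim, bias=bias)

    def forward(self, src, mask=None, is_causal=True):
        q, k, v = self.qkv(src).chunk(3, dim=-1)
        # (..., S, mod_dim) -> (..., n_heads, S, head_dim). Rotary encodings
        # would be applied to q and k at this point.
        q, k, v = (
            t.unflatten(-1, (self.n_heads, self.head_dim)).transpose(-3, -2)
            for t in (q, k, v)
        )
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None if is_causal else mask,  # mask ignored if causal
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # Merge heads back: (..., n_heads, S, head_dim) -> (..., S, mod_dim).
        return self.out(attn.transpose(-3, -2).flatten(-2))
```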
- property context
Maximum context length of the positional encodings, if present.
- property device
Device to compute self-attention on.
- property dtype
Dtype to compute self-attention in.
- forward(src, mask=None, is_causal=True)
Forward pass through multi-headed self-attention.
- Parameters:
src (Tensor) – Input sequence(s) of dimensions (…, S, mod_dim), with sequence length S.
mask (Tensor, optional) – Attention mask with a shape broadcastable to the shape of the attention weights (…, S, S). Two types of masks are supported: a boolean mask, where a value of True indicates that the element should take part in attention, or a float mask of the same dtype as src that is added to the product of queries and keys before taking the softmax. In the latter case, a value of 0.0 (resulting in unchanged attention weights) indicates that an element should take part in the attention and a value of “-inf” (resulting in a zero attention weight) that it should not. Defaults to None.
is_causal (bool, optional) – If set to True, inputs are masked with an S x S lower-triangular matrix and mask is ignored. Defaults to True.
- Returns:
The output has the same shape as the input.
- Return type:
Tensor
Important
Following the convention of scaled_dot_product_attention, True and False in boolean attention masks mean “attend to” and “do not attend to”, respectively. This is exactly the opposite of their meaning in PyTorch's MultiheadAttention. Therefore, to stay compatible, use float masks!
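The is_causal flag is shorthand for a lower-triangular mask. The snippet below uses only torch.nn.functional.scaled_dot_product_attention, whose mask convention these attention blocks follow. It checks that the flag, a boolean lower-triangular mask, and the equivalent float mask produce the same result on random toy data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, head_dim = 6, 8
# (batch, heads, S, head_dim)
q, k, v = (torch.randn(2, 3, S, head_dim) for _ in range(3))

# Boolean S x S lower-triangular mask (True = attend, SDPA convention) ...
lower = torch.ones(S, S, dtype=torch.bool).tril()
# ... and the equivalent float mask with -inf strictly above the diagonal.
float_causal = torch.zeros(S, S).masked_fill(~lower, float("-inf"))

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=lower)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_causal)
```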
- property has_pos_enc
Whether a pos_enc module was provided at instantiation or not.
- property head_dim
The dimension of each attention head.
- property mod_dim
The model dimension.
- property n_heads
The number of attention heads used.
- class GroupedQuerySelfAttention(mod_dim, n_heads=1, q_factor=1, bias=False, dropout=0.1, pos_enc=None, device='cpu', dtype=torch.float32)
Bases: Attn
Grouped-query attention with optional (rotary) positional encodings.
- Parameters:
mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.
n_heads (int, optional) – The number of attention heads. Must integer divide mod_dim, and the result must still be an even number. Defaults to 1.
q_factor (int, optional) – Reduce the number of attention heads for keys and values by this factor compared to the n_heads used for queries. Must integer divide n_heads. Realizes standard multi-head attention (MHA) with q_factor=1, multi-query attention (MQA) with q_factor=n_heads, and grouped-query attention (GQA) otherwise. Defaults to 1 (i.e., standard MHA).
bias (bool, optional) – Whether to add learnable bias vectors in the projections from input to query, key and value and the final out projection. Defaults to False.
dropout (float, optional) – Apply dropout to the attention weights with this probability during training. Defaults to 0.1.
pos_enc (PosEnc, optional) – PyTorch Module that has a reset_parameters() method, has a new() method to make fresh copies of itself, has a context attribute specifying the maximum sequence length, and processes tensors with dimensions (…, n_heads, S, head_dim), where S is the sequence length and head_dim is mod_dim divided by n_heads. If given, it will be called on queries and keys. Typically, this would be an instance of Rotary positional encodings. Defaults to an instance of IdentityBlock, which does nothing.
device (str or torch.device, optional) – Torch device to compute self-attention on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype to compute self-attention in. Defaults to torch.float.
- Raises:
ValueError – If n_heads does not integer divide mod_dim or if q_factor does not integer divide n_heads.
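The q_factor bookkeeping can be illustrated in a few lines. The sketch assumes, without confirmation from the documentation, that query heads are grouped consecutively. Under that assumption, query heads 0 to q_factor - 1 share the first key/value head, and so on. grouped_query_attention is a hypothetical function operating on already projected heads.

```python
import torch
import torch.nn.functional as F


def grouped_query_attention(q, k, v, is_causal=True):
    """Sketch: q is (..., n_heads, S, head_dim); k, v are (..., n_kv_heads, S, head_dim)."""
    n_heads, n_kv_heads = q.size(-3), k.size(-3)
    if n_heads % n_kv_heads:
        raise ValueError("n_kv_heads must integer divide n_heads!")
    q_factor = n_heads // n_kv_heads
    # Every key/value head is shared by q_factor consecutive query heads.
    k = k.repeat_interleave(q_factor, dim=-3)
    v = v.repeat_interleave(q_factor, dim=-3)
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)


torch.manual_seed(0)
batch, S, head_dim = 2, 5, 8
n_heads, q_factor = 8, 4  # 8 query heads share 8 // 4 = 2 key/value heads
n_kv_heads = n_heads // q_factor
q = torch.randn(batch, n_heads, S, head_dim)
k = torch.randn(batch, n_kv_heads, S, head_dim)
v = torch.randn(batch, n_kv_heads, S, head_dim)
out = grouped_query_attention(q, k, v)
```

Keys and values only need n_kv_heads * head_dim features per token (the kv_dim property below). This is where the memory savings of GQA and MQA come from.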
- property context
Maximum context length of the positional encodings, if present.
- property device
Device to compute self-attention on.
- property dtype
Dtype to compute self-attention in.
- forward(src, mask=None, is_causal=True)
Forward pass through grouped-query attention.
- Parameters:
src (Tensor) – Input sequence(s) of dimensions (…, S, mod_dim), with sequence length S.
mask (Tensor, optional) – Attention mask with a shape broadcastable to the shape of the attention weights (…, S, S). Two types of masks are supported: a boolean mask, where a value of True indicates that the element should take part in attention, or a float mask of the same dtype as src that is added to the product of queries and keys before taking the softmax. In the latter case, a value of 0.0 (resulting in unchanged attention weights) indicates that an element should take part in the attention and a value of “-inf” (resulting in a zero attention weight) that it should not. Defaults to None.
is_causal (bool, optional) – If set to True, inputs are masked with an S x S lower-triangular matrix and mask is ignored. Defaults to True.
- Returns:
The output has the same shape as the input.
- Return type:
Tensor
Important
Following the convention of scaled_dot_product_attention, True and False in boolean attention masks mean “attend to” and “do not attend to”, respectively. This is exactly the opposite of their meaning in PyTorch's MultiheadAttention. Therefore, to stay compatible, use float masks!
- property has_pos_enc
Whether a pos_enc module was provided at instantiation or not.
- property head_dim
The dimension of each attention head.
- property kv_dim
Total dimension of keys (or values), i.e., n_kv_heads * head_dim.
- property mod_dim
The model dimension.
- property n_heads
The number of attention heads used.
- property n_kv_heads
The number of key/value heads.
Positional encodings
- class Sinusoidal(mod_dim, context, *_, device='cpu', dtype=torch.float32, **__)
Bases: PosEnc
Sinusoidal positional encodings for transformer-based sequence models.
- Parameters:
mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.
context (int) – The maximum sequence length that can be processed. Inputs are expected to not exceed this size in their next-to-last dimension.
device (str or torch.device, optional) – Torch device to first create the sinusoidal positional encodings on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype to first create the sinusoidal positional encodings in. Defaults to torch.float.
- property context
The maximum sequence length.
- property device
Device that the sinusoidal positional encodings reside on.
- property dtype
Dtype of the sinusoidal positional encodings.
- forward(src, offset=0)
Add sinusoidal positional encodings to a sequence of embeddings.
- Parameters:
src (Tensor) – Input sequence(s). Must be of dimensions (…, S, mod_dim), where the sequence length S must not exceed context.
offset (int, optional) – Offset to add to the index of the positional encodings. Defaults to 0.
- Returns:
The input sequence(s) with sinusoidal positional encodings added.
- Return type:
Tensor
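This section specifies neither the frequency base nor the feature layout. The sketch below assumes the classic choice from "Attention Is All You Need": base 10000, sines on even features and cosines on odd ones. It also shows how offset shifts the positions, for example when a sequence is processed chunk by chunk. sinusoidal_table and add_sinusoidal are hypothetical helpers, and the error for exceeding the context is our own choice.

```python
import math

import torch


def sinusoidal_table(context: int, mod_dim: int) -> torch.Tensor:
    """Classic sine/cosine table of shape (context, mod_dim); mod_dim must be even."""
    positions = torch.arange(context, dtype=torch.float32).unsqueeze(1)  # (context, 1)
    # Angular frequencies 1 / 10000^(2i / mod_dim) for i = 0, ..., mod_dim / 2 - 1.
    freqs = torch.exp(
        torch.arange(0, mod_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / mod_dim)
    )
    table = torch.zeros(context, mod_dim)
    table[:, 0::2] = torch.sin(positions * freqs)  # even features
    table[:, 1::2] = torch.cos(positions * freqs)  # odd features
    return table


def add_sinusoidal(src: torch.Tensor, table: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Add encodings for positions offset, ..., offset + S - 1 to src of shape (..., S, mod_dim)."""
    seq_len = src.size(-2)
    if offset + seq_len > table.size(0):
        raise ValueError("Sequence length plus offset exceeds the context of the table!")
    return src + table[offset:offset + seq_len]
```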
- property mod_dim
The model dimension.
- class Rotary(mod_dim, context, n_heads, device='cpu', dtype=torch.float32, **_)
Bases: PosEnc
Rotary positional encodings for multi-head attention in sequence models.
- Parameters:
mod_dim (int) – The model dimension. Each vector in the original sequence is expected to be of that dimension.
context (int) – The maximum sequence length that can be processed. Inputs are expected to not exceed this size in their next-to-last dimension.
n_heads (int) – The number of attention heads. Must integer divide mod_dim, and the result must still be an even number.
device (str or torch.device, optional) – Torch device to first create the rotary positional encodings on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype to first create the rotary positional encodings in. Defaults to torch.float.
- Raises:
ValueError – If n_heads does not integer divide mod_dim or if the result is not an even number.
- property context
The maximum sequence length.
- property device
Device that the rotary positional encodings reside on.
- property dtype
Dtype of the rotary positional encodings.
- forward(src, offset=0)
Apply rotary positional encodings across all heads of the input.
- Parameters:
src (Tensor) – Input sequence(s) for all heads. Must be of dimensions (…, n_heads, S, head_dim), where the sequence length S must not exceed context and head_dim is the mod_dim divided by n_heads.
offset (int, optional) – Unused. Only for purposes of API compatibility.
- Returns:
The input sequence(s) with rotary positional encodings applied to all heads.
- Return type:
Tensor
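The rotation that Rotary applies to queries and keys can be sketched as follows. The frequency base of 10000 and the interleaved pairing of features (2i, 2i + 1) are assumptions, since the documentation does not state them. Some implementations instead rotate the first half of the features against the second half. rotary_tables and apply_rotary are hypothetical helpers. Nothing is added to the input: each feature pair is rotated by an angle proportional to its position, which is why head_dim must be even.

```python
import torch


def rotary_tables(context: int, head_dim: int, base: float = 10000.0):
    """Cosine and sine tables of shape (context, head_dim // 2); head_dim must be even."""
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    angles = torch.arange(context, dtype=torch.float32).unsqueeze(1) * inv_freq
    return angles.cos(), angles.sin()


def apply_rotary(src: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate feature pairs (2i, 2i + 1) of src, shaped (..., n_heads, S, head_dim)."""
    seq_len = src.size(-2)
    cos, sin = cos[:seq_len], sin[:seq_len]  # broadcast over batch and heads
    even, odd = src[..., 0::2], src[..., 1::2]
    rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
    return rotated.flatten(-2)  # re-interleave the rotated pairs
```

Because queries and keys are rotated alike, their dot product depends only on the distance between their positions, not on the absolute positions themselves.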
- property head_dim
The dimension of each attention head.
- property mod_dim
The model dimension.
- class Learnable(mod_dim, context, *_, device='cpu', dtype=torch.float32, **__)
Bases: PosEnc
Learnable positional encodings for transformer-based sequence models.
- Parameters:
mod_dim (int) – The model dimension. Inputs are expected to be of that size in their last dimension.
context (int) – The maximum sequence length that can be processed. Inputs are expected to not exceed this size in their next-to-last dimension.
device (str or torch.device, optional) – Torch device to create the learnable positional encodings on. Defaults to “cpu”.
dtype (torch.dtype, optional) – Torch dtype of the learnable positional encodings. Defaults to torch.float.
Warning
Make sure that the context reflects the maximum length of the sequences that your model sees at training time. In contrast to other types of positional encodings, which can reasonably be expected to generalize well beyond that during inference, positions that have never been encountered during training cannot be encoded at all with Learnable.
- property context
The maximum sequence length.
- property device
Device that the learnable positional encodings reside on.
- property dtype
Dtype of the learnable positional encodings.
- forward(src, offset=0)
Add learnable positional encodings to a sequence of embeddings.
- Parameters:
src (Tensor) – Input sequence(s). Must be of dimensions (…, S, mod_dim), where the sequence length S must not exceed context.
offset (int, optional) – Offset to add to the index of the positional encodings. Defaults to 0.
- Returns:
The input sequence(s) with positional encodings added.
- Return type:
Tensor
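To see why the warning above matters, here is a hypothetical sketch of learnable encodings as a trainable table with one row per position. SketchLearnable, its normal initialization, and its error message are assumptions, not this module's implementation. A position at or beyond context has no row to look up, and rows for positions that never occur during training receive no gradient.

```python
import torch
from torch import nn


class SketchLearnable(nn.Module):
    """Hypothetical learnable positional encodings: one trainable vector per position."""

    def __init__(self, mod_dim: int, context: int):
        super().__init__()
        self.context = context
        self.table = nn.Parameter(torch.empty(context, mod_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.normal_(self.table, std=0.02)  # initialization scheme is an assumption

    def forward(self, src: torch.Tensor, offset: int = 0) -> torch.Tensor:
        seq_len = src.size(-2)
        if offset + seq_len > self.context:
            # Positions beyond the table simply have no encoding to look up.
            raise ValueError(f"Positions up to {offset + seq_len - 1} exceed context={self.context}!")
        return src + self.table[offset:offset + seq_len]
```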
- property mod_dim
The model dimension.
Base classes
- class PosEnc(*_, **__)
Bases: Block
- abstract property context
The maximum sequence length.
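For a custom positional encoding to slot into the components above, the parameter descriptions ask for four things: a context attribute, a reset_parameters() method, a new() method, and processing of (…, S, D) tensors. The sketch below mimics the abstract context property with Python's abc module. SketchPosEncBase and FixedRamp are hypothetical. Plain nn.Module replaces the Block base class, whose interface is not documented in this section, and the ramp itself is a toy rather than a recommended encoding.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class SketchPosEncBase(nn.Module, ABC):
    """Hypothetical analogue of PosEnc: subclasses must define context."""

    @property
    @abstractmethod
    def context(self) -> int:
        """The maximum sequence length."""


class FixedRamp(SketchPosEncBase):
    """Toy encoding that adds position / context to every feature."""

    def __init__(self, context: int):
        super().__init__()
        self._context = context

    @property
    def context(self) -> int:
        return self._context

    def reset_parameters(self) -> None:
        pass  # nothing learnable to reset

    def new(self) -> "FixedRamp":
        return FixedRamp(self.context)  # fresh copy with the same settings

    def forward(self, src: torch.Tensor, offset: int = 0) -> torch.Tensor:
        seq_len = src.size(-2)
        positions = torch.arange(offset, offset + seq_len, dtype=src.dtype, device=src.device)
        return src + (positions / self.context).unsqueeze(-1)
```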