Source code for keras_mml.layers.transformer.transformer

"""
Implements a matmul-less transformer block.
"""

from typing import Tuple

import keras
import numpy as np
from jaxtyping import Float

from keras_mml.layers.activations import SwiGLUMML
from keras_mml.layers.normalizations.rms_norm import RMSNorm
from keras_mml.layers.transformer.attention import AttentionMML


[docs] @keras.saving.register_keras_serializable(package="keras_mml") class TransformerBlockMML(keras.Layer): """ Transformer block layer that is mostly without matrix multiplications. The core flow of the transformer block follows the |AttentionPaper|_ paper, while referencing the Keras example |KerasTransformer|_ for its high-level implementation. However, we use the custom :py:class:`~keras_mml.layers.transformer.AttentionMML` class for the attention mechanism and :py:class:`~keras_mml.layers.activations.SwiGLUMML` for the feed-forward network (FFN) part. Attributes: embedding_dim: Dimension of the embeddings. ffn_dim: Dimension of the intermediate (i.e., hidden) layer of the feed-forward network. num_heads: Number of heads to use for multi-headed attention. fully_mml: Whether to use full matmul-less layers in the attention mechanism. rate: Dropout rate to apply for the attention mechanism and the feed-forward network. .. |AttentionPaper| replace:: *Attention Is All You Need* .. _AttentionPaper: https://arxiv.org/pdf/1706.03762v7 .. |KerasTransformer| replace:: *Text classification with Transformer* .. _KerasTransformer: https://keras.io/examples/nlp/text_classification_with_transformer/ """
[docs] def __init__( self, embedding_dim: int, ffn_dim: int, num_heads: int, fully_mml: bool = True, rate: float = 0.1, **kwargs, ): """ Initializes a new instance of the layer. Args: embedding_dim: Dimension of the embeddings. ffn_dim: Dimension of the intermediate (i.e., hidden) layer of the feed-forward network. num_heads: Number of heads to use for multi-headed attention. fully_mml: Whether to use full matmul-less layers in the attention mechanism. rate: Dropout rate to apply for the attention mechanism and the feed-forward network. **kwargs: Keyword arguments for :py:class:`keras.Layer`. Raises: ValueError: If the embedding dimension is not a positive integer. ValueError: If the dimension of the intermediate layer of the feed-forward network is not a positive integer. ValueError: If the number of heads is not a positive integer. ValueError: If the embedding dimension is not divisible by the number of heads. """ if embedding_dim <= 0: raise ValueError( f"Received an invalid value for embedding dimension, expected a positive integer, got {embedding_dim}" ) if ffn_dim <= 0: raise ValueError( "Received an invalid value for the feed forward network dimension, " f"expected a positive integer, got {ffn_dim}" ) if num_heads <= 0: raise ValueError( f"Received an invalid value for the number of heads, expected a positive integer, got {num_heads}" ) if embedding_dim % num_heads != 0: raise ValueError( "Embedding dimension must be divisible by the number of heads. " f"Got embedding dimension of {embedding_dim} but wanted to use {num_heads} heads." ) super().__init__(**kwargs) self.input_spec = keras.layers.InputSpec(ndim=3) # Main attributes self.embedding_dim = embedding_dim self.ffn_dim = ffn_dim self.num_heads = num_heads self.fully_mml = fully_mml self.rate = rate # Hidden weights/layers self._attention_layer = AttentionMML(num_heads, embedding_dim, fully_mml=fully_mml) self._attention_dropout = keras.layers.Dropout(rate) self._attention_norm = RMSNorm() self._ffn_layer = SwiGLUMML(embedding_dim, intermediate_size=ffn_dim) self._ffn_dropout = keras.layers.Dropout(rate) self._ffn_norm = RMSNorm()
[docs] def build(self, input_shape: Tuple[int, int, int]): """ Build the layer. Args: input_shape: Shape of the input. """ self._attention_layer.build(input_shape) intermediate_shape = self._attention_layer.compute_output_shape(input_shape) self._attention_dropout.build(intermediate_shape) self._attention_norm.build(intermediate_shape) self._ffn_layer.build(intermediate_shape) intermediate_shape = self._ffn_layer.compute_output_shape(intermediate_shape) self._ffn_dropout.build(intermediate_shape) self._ffn_norm.build(intermediate_shape) self.built = True
[docs] def call( self, inputs: Float[np.ndarray, "batch_size sequence_length features"] ) -> Float[np.ndarray, "batch_size sequence_length embedding_dim"]: """ Calling method of the layer. Args: inputs: Inputs into the layer. Returns: Transformed inputs. """ attention_output = self._attention_layer(inputs) attention_output = self._attention_dropout(attention_output) attention_output = self._attention_norm(inputs + attention_output) ffn_output = self._ffn_layer(attention_output) ffn_output = self._ffn_dropout(ffn_output) ffn_output = self._ffn_norm(attention_output + ffn_output) return ffn_output
[docs] def compute_output_shape(self, input_shape: Tuple[int, int, int]) -> Tuple[int, int, int]: """ Computes the output shape of the layer. Args: input_shape: Shape of the input into the layer. Returns: Shape of the output. """ input_shape = list(input_shape) input_shape[-1] = self.embedding_dim return tuple(input_shape)