"""
Implements embedding layers.
"""
from typing import Tuple
import keras
import numpy as np
from jaxtyping import Float
from keras import ops
from keras_mml.layers.core.dense import DenseMML
@keras.saving.register_keras_serializable(package="keras_mml")
class TokenEmbedding(keras.layers.Layer):
    """
    Turns positive integers (indices) into vectors of fixed size.

    For example, ``[[1, 2], [3, 4], [5, 6]]``, which could be interpreted as 3 sentences with 2
    words each, could be embedded as ``[[[0.1, 0.2, 0.3], [0.3, 0.4, 0.5]], [[1.1, 1.2, 1.3],
    [1.3, 1.4, 1.5]], [[2.1, 2.2, 2.3], [2.3, 2.4, 2.5]]]``, which has shape ``(3, 2, 3)`` and can
    be interpreted as 3 sentences with 2 words each with an embedding dimension of 3.

    This layer could optionally include position information in the embeddings by enabling the
    :py:attr:`with_positions` attribute. When it is enabled, a learned embedding for each position
    ``0 .. sequence_len - 1`` is added to the token embeddings, so the sequences passed in must not
    be longer than :py:attr:`max_len`.

    Attributes:
        max_len: Maximum length of a sentence.
        vocab_size: Size of the vocabulary. Typically this is one more than the maximum integer
            index.
        embedding_dim: Embedding dimension.
        with_positions: Whether to include position information in the embeddings.
    """

    def __init__(self, max_len: int, vocab_size: int, embedding_dim: int, with_positions: bool = False, **kwargs):
        """
        Initializes a new instance of the layer.

        Args:
            max_len: Maximum length of a sentence.
            vocab_size: Size of the vocabulary. Typically this is one more than the maximum integer
                index.
            embedding_dim: Embedding dimension.
            with_positions: Whether to include position information in the embeddings.
            **kwargs: Keyword arguments for :py:class:`keras.Layer`.

        Raises:
            ValueError: If the maximum sentence length is not a positive integer.
            ValueError: If the vocabulary size is not a positive integer.
            ValueError: If the embedding dimension is not a positive integer.
        """

        if max_len <= 0:
            raise ValueError(
                "Received an invalid value for the maximum sentence length. "
                f"Expected a positive integer, but got {max_len} instead."
            )
        if vocab_size <= 0:
            raise ValueError(
                "Received an invalid value for the vocabulary size. "
                f"Expected a positive integer, but got {vocab_size} instead."
            )
        if embedding_dim <= 0:
            raise ValueError(
                "Received an invalid value for the embedding dimension. "
                f"Expected a positive integer, but got {embedding_dim} instead."
            )

        super().__init__(**kwargs)

        # Inputs are a 2D batch of token indices: (batch_size, sequence_len)
        self.input_spec = keras.layers.InputSpec(ndim=2)

        # Main attributes
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.with_positions = with_positions

        # Hidden weights/layers
        self._token_embedding = keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)
        if with_positions:
            # One learned vector per position, hence a table with `max_len` rows
            self._pos_embedding = keras.layers.Embedding(input_dim=max_len, output_dim=embedding_dim)
        else:
            self._pos_embedding = None

    def _check_sequence_len(self, sequence_len) -> None:
        """
        Ensures that a statically known sequence length fits in the position embedding table.

        Args:
            sequence_len: Size of the sequence axis. Anything that is not a Python integer (for
                example ``None`` or a dynamic dimension) cannot be checked here and is skipped.

        Raises:
            ValueError: If position embeddings are enabled and the sequence is longer than
                :py:attr:`max_len`.
        """

        if self._pos_embedding is None or not isinstance(sequence_len, int):
            return

        # The position table only has `max_len` rows; a longer sequence would look up rows that do
        # not exist, and how such a lookup behaves is left to the backend.
        if sequence_len > self.max_len:
            raise ValueError(
                "Received a sequence that is too long for the position embedding. "
                f"Expected at most {self.max_len} tokens, but got {sequence_len} instead."
            )

    def build(self, input_shape: Tuple[int, int]):
        """
        Build the layer.

        Args:
            input_shape: Shape of the input.

        Raises:
            ValueError: If position embeddings are enabled and the (statically known) sequence
                length exceeds :py:attr:`max_len`.
        """

        self._check_sequence_len(input_shape[-1])

        self._token_embedding.build(input_shape)
        if self._pos_embedding is not None:
            self._pos_embedding.build(input_shape)
        self.built = True

    def call(
        self, inputs: Float[np.ndarray, "batch_size sequence_len"]
    ) -> Float[np.ndarray, "batch_size sequence_len embedding_dim"]:
        """
        Calling method of the layer.

        Args:
            inputs: Inputs into the layer.

        Returns:
            Transformed inputs.

        Raises:
            ValueError: If position embeddings are enabled and the (statically known) sequence
                length exceeds :py:attr:`max_len`.
        """

        tokens = self._token_embedding(inputs)
        if self._pos_embedding is None:
            return tokens

        # Positions follow the length of *this* input rather than `max_len`, so shorter sequences
        # only use the first rows of the position table.
        sequence_len = ops.shape(inputs)[-1]
        self._check_sequence_len(sequence_len)

        positions = ops.arange(start=0, stop=sequence_len, step=1)
        positions = self._pos_embedding(positions)

        # `positions` has no batch axis; the addition broadcasts it over the batch
        return tokens + positions

    def get_config(self) -> dict:
        """
        Returns the configuration of the layer.

        The constructor arguments are listed explicitly so that the layer can be re-created with
        :py:meth:`from_config` without relying on Keras capturing them automatically.

        Returns:
            Configuration of the base layer, extended with this layer's constructor arguments.
        """

        config = super().get_config()
        config.update(
            {
                "max_len": self.max_len,
                "vocab_size": self.vocab_size,
                "embedding_dim": self.embedding_dim,
                "with_positions": self.with_positions,
            }
        )
        return config
@keras.saving.register_keras_serializable(package="keras_mml")
class PatchEmbedding(keras.layers.Layer):
    """
    Turns image patches into vectors of fixed size.

    The image patches should have come from the :py:class:`~keras_mml.layers.misc.Patches` layer.

    This layer could optionally include position information in the embeddings by enabling the
    :py:attr:`with_positions` attribute. When it is enabled, a learned embedding for each of the
    :py:attr:`num_patches` positions is added to the projected patches, so the inputs must contain
    exactly :py:attr:`num_patches` patches per image.

    Attributes:
        num_patches: Number of patches in each image.
        embedding_dim: Embedding dimension.
        use_mml: Whether to use a matmul-less projection to embed the patches.
        with_positions: Whether to include position information in the embeddings.
    """

    def __init__(
        self, num_patches: int, embedding_dim: int, use_mml: bool = True, with_positions: bool = False, **kwargs
    ):
        """
        Initializes a new instance of the layer.

        Args:
            num_patches: Number of patches in each image.
            embedding_dim: Embedding dimension.
            use_mml: Whether to use a matmul-less projection to embed the patches.
            with_positions: Whether to include position information in the embeddings.
            **kwargs: Keyword arguments for :py:class:`keras.Layer`.

        Raises:
            ValueError: If the number of patches is not a positive integer.
            ValueError: If the embedding dimension is not a positive integer.
        """

        if num_patches <= 0:
            raise ValueError(f"Invalid number of patches, expected a positive integer, got {num_patches}")
        if embedding_dim <= 0:
            raise ValueError(f"Invalid embedding dimension, expected a positive integer, got {embedding_dim}")

        super().__init__(**kwargs)

        # Inputs are a 3D batch of flattened patches: (batch_size, patch_count, patch_dim)
        self.input_spec = keras.layers.InputSpec(ndim=3)

        # Main attributes
        self.num_patches = num_patches
        self.embedding_dim = embedding_dim
        self.use_mml = use_mml
        self.with_positions = with_positions

        # Hidden weights/layers
        if self.use_mml:
            self._projection = DenseMML(embedding_dim)
        else:
            self._projection = keras.layers.Dense(embedding_dim)

        if with_positions:
            # One learned vector per patch position, hence a table with `num_patches` rows
            self._pos_embedding = keras.layers.Embedding(input_dim=num_patches, output_dim=embedding_dim)
        else:
            self._pos_embedding = None

    def build(self, input_shape: Tuple[int, int, int]):
        """
        Build the layer.

        Args:
            input_shape: Shape of the input.

        Raises:
            ValueError: If position embeddings are enabled and the (statically known) number of
                patches in the input differs from :py:attr:`num_patches`.
        """

        # The position embeddings always cover exactly `num_patches` positions, so an input with a
        # different patch count cannot be combined with them meaningfully. Fail early with a clear
        # message instead of leaving it to the addition in `call`.
        patch_count = input_shape[1]
        if self._pos_embedding is not None and patch_count is not None and patch_count != self.num_patches:
            raise ValueError(
                f"Invalid number of patches in the input, expected {self.num_patches}, got {patch_count}"
            )

        self._projection.build(input_shape)
        if self._pos_embedding is not None:
            self._pos_embedding.build(input_shape)
        self.built = True

    def call(
        self, inputs: Float[np.ndarray, "batch_size patch_count patch_dim"]
    ) -> Float[np.ndarray, "batch_size patch_count embedding_dim"]:
        """
        Calling method of the layer.

        Args:
            inputs: Inputs into the layer.

        Returns:
            Transformed inputs.
        """

        projected_patches = self._projection(inputs)
        if self._pos_embedding is None:
            return projected_patches

        # Leading axis of size 1 so that the position embeddings broadcast over the batch
        position_ids = ops.arange(start=0, stop=self.num_patches, step=1)
        position_ids = ops.expand_dims(position_ids, axis=0)
        positions = self._pos_embedding(position_ids)
        return projected_patches + positions

    def get_config(self) -> dict:
        """
        Returns the configuration of the layer.

        The constructor arguments are listed explicitly so that the layer can be re-created with
        :py:meth:`from_config` without relying on Keras capturing them automatically.

        Returns:
            Configuration of the base layer, extended with this layer's constructor arguments.
        """

        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "embedding_dim": self.embedding_dim,
                "use_mml": self.use_mml,
                "with_positions": self.with_positions,
            }
        )
        return config