"""
Implements a matmul-less Dense layer.
"""
from typing import Dict, Optional, Tuple
import keras
import numpy as np
from einops import asnumpy
from jaxtyping import Float
from keras import activations, constraints, initializers, ops, regularizers
from keras_mml.layers.core._dense_impl import BackendDenseMML
from keras_mml.layers.normalizations.rms_norm import RMSNorm
from keras_mml.utils.array.encoding import decode_ternary_array, encode_ternary_array
# Small positive constant. It is not referenced in this part of the file; presumably it is a
# numerical floor used to avoid division by zero (TODO confirm where it is used).
EPSILON: float = 0.00001
# Large constant. It is not referenced in this part of the file; presumably it is an upper
# clipping bound (TODO confirm where it is used).
HUGE: float = 1.0e9
[docs]
@keras.saving.register_keras_serializable(package="keras_mml")
class DenseMML(BackendDenseMML):
"""
Dense layer without matrix multiplications.
The core of the layer is the ``BitLinear`` layer described in |1.58 Bit LLMs|_. It uses ternary
quantization to reduce matrix multiplication operations to simple addition and subtraction.
This implementation differs from ``BitLinear`` by allowing an activation function to be
specified. More precisely, :py:class:`~DenseMML` implements the operation
.. math::
\\mathbf{y} = \\sigma\\left(\\mathbf{x}\\mathbf{W} + \\mathbf{b}\\right)
where :math:`\\mathbf{x}` is the normalized and quantized input vector, :math:`\\mathbf{W}` is the quantized
weights matrix (i.e., the *kernel* matrix), :math:`\\mathbf{b}` is the bias vector, and
:math:`\\sigma` is the element-wise activation function.
.. IMPORTANT::
See the :ref:`pitfalls <explanations/dense:pitfalls>` when using this layer.
.. NOTE::
If the input to the layer (say :math:`\\mathbf{x}`) has a rank greater than 2, then this
computes the dot product of :math:`\\mathbf{x}` and :math:`\\mathbf{W}` along the last axis
of :math:`\\mathbf{x}` and axis ``0`` of :math:`\\mathbf{W}`.
For example, suppose :math:`\\mathbf{x}` has shape ``(batch_size, d0, d1)``. Then
:math:`\\mathbf{W}` is created to have shape ``(d1, units)`` and it operates along axis ``2``
of :math:`\\mathbf{x}` on every sub-tensor of shape ``(1, 1, d1)`` (there are
``batch_size * d0`` such sub-tensors). The output in this case will have shape
``(batch_size, d0, units)``.
.. WARNING::
Once a model that uses this layer is loaded from a file, it **cannot** be retrained.
Attributes:
units: Dimensionality of the output space.
activation: Activation function applied to the output of the layer.
use_bias: Whether the layer uses a bias vector.
kernel_initializer: Initializer for the kernel matrix.
bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to the kernel matrix.
bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to the kernel matrix.
bias_constraint: Constraint function applied to the bias vector.
.. |1.58 Bit LLMs| replace:: *The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits*
.. _1.58 Bit LLMs: https://arxiv.org/pdf/2402.17764v1
"""
[docs]
def __init__(
    self,
    units: int,
    activation: Optional[str] = None,
    use_bias: bool = True,
    kernel_initializer: str = "glorot_uniform",
    bias_initializer: str = "zeros",
    kernel_regularizer: Optional[str] = None,
    bias_regularizer: Optional[str] = None,
    activity_regularizer: Optional[str] = None,
    kernel_constraint: Optional[str] = None,
    bias_constraint: Optional[str] = None,
    **kwargs,
):
    """
    Constructs a :py:class:`~DenseMML` layer.

    Args:
        units: Size of the last axis of the layer's output.
        activation: Activation function, resolved with ``keras.activations.get``. When left as
            ``None``, no activation is applied (i.e. "linear" activation:
            :math:`\\sigma(\\mathbf{x}) = \\mathbf{x}`).
        use_bias: If ``True``, a bias vector is created in ``build`` and added before the
            activation.
        kernel_initializer: Initializer used to create the kernel matrix.
        bias_initializer: Initializer used to create the bias vector.
        kernel_regularizer: Regularizer applied to the kernel matrix.
        bias_regularizer: Regularizer applied to the bias vector.
        activity_regularizer: Regularizer applied to the output of the layer (i.e., its
            activation).
        kernel_constraint: Constraint applied to the kernel matrix.
        bias_constraint: Constraint applied to the bias vector.
        **kwargs: Extra keyword arguments forwarded to :py:class:`keras.Layer`.

    Raises:
        ValueError: If ``units`` is zero or negative.
    """
    # Validate before touching any Keras state.
    if units <= 0:
        raise ValueError(
            f"Received an invalid value for argument `units`, expected a positive integer, got {units}"
        )

    super().__init__(activity_regularizer=activity_regularizer, **kwargs)

    # Until `build` runs, accept any input of rank two or more.
    self.input_spec = keras.layers.InputSpec(min_ndim=2)
    self.supports_masking = True

    # Configuration exposed as public attributes
    self.units = units
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer, self.bias_initializer = (
        initializers.get(kernel_initializer),
        initializers.get(bias_initializer),
    )
    self.kernel_regularizer, self.bias_regularizer = (
        regularizers.get(kernel_regularizer),
        regularizers.get(bias_regularizer),
    )
    self.kernel_constraint, self.bias_constraint = (
        constraints.get(kernel_constraint),
        constraints.get(bias_constraint),
    )

    # Only set when the layer is restored from a file (see `load_own_variables`).
    self._kernel_scale = None

    # Placeholders for the normalization sub-layer and the weights created in `build`.
    self._activation_norm = self._kernel = self._bias = None
# Public methods
[docs]
def build(self, input_shape: Tuple[int, ...]):
    """
    Creates the normalization sub-layer and the layer's weights.

    Args:
        input_shape: Shape of the input. Its last entry sets the number of rows of the kernel
            matrix and is pinned in the layer's input spec.
    """
    features_in = input_shape[-1]

    # Normalizer that `call` applies to the inputs before they are quantized.
    self._activation_norm = RMSNorm()
    self._activation_norm.build(input_shape)

    self._kernel = self.add_weight(
        shape=(features_in, self.units),
        name="kernel",
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
    )

    # Without a bias, `self._bias` is left untouched.
    if self.use_bias:
        self._bias = self.add_weight(
            shape=(self.units,),
            name="bias",
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
        )

    # From now on, the last input axis must match the kernel's row count.
    self.input_spec = keras.layers.InputSpec(min_ndim=2, axes={-1: features_in})
    self.built = True
[docs]
def call(
    self, inputs: Float[np.ndarray, "batch_size *dims last_dim"]
) -> Float[np.ndarray, "batch_size *dims units"]:
    """
    Runs the forward pass of the layer.

    The inputs are normalized, then quantized together with the kernel. The quantized arrays
    are combined by ternary multiplication, and the result is shifted by the bias (if one
    exists) and passed through the activation.

    Args:
        inputs: Inputs into the layer; the last axis must match the input dimension seen in
            ``build``.

    Returns:
        Transformed inputs, whose last axis has size ``units``.
    """
    normalized = self._activation_norm(inputs)
    quantized_inputs, quantized_kernel, kernel_scale = self._get_quantized_arrays(normalized)
    outputs = self._ternary_multiplication(quantized_inputs, quantized_kernel, kernel_scale)

    if self._bias is not None:
        outputs = ops.add(outputs, self._bias)
    return outputs if self.activation is None else self.activation(outputs)
[docs]
def compute_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Determines the shape the layer produces for an input of the given shape.

    Only the last axis changes: it is replaced by ``units``, and all leading axes are kept.

    Args:
        input_shape: Shape of the input to the layer.

    Returns:
        The corresponding output shape, as a tuple.
    """
    dims = [*input_shape]
    dims[-1] = self.units  # the last axis becomes the number of output features
    return tuple(dims)
[docs]
def save_own_variables(self, store: Dict):
    """
    Saves the state of the layer.

    The kernel is written in quantized (ternary) form, packed with ``encode_ternary_array``,
    together with its shape and scale. The bias (if the layer uses one) is written as-is.

    Args:
        store: Dictionary where the state of the model will be saved.
    """
    if self._kernel_scale is None:
        # Freshly built/trained layer: pre-quantize the full-precision kernel values.
        w_scale = self._compute_kernel_scale(self._kernel)
        w_quantized = asnumpy(self._quantize_kernel(self._kernel, w_scale))
    else:
        # Layer restored by `load_own_variables`: the kernel already holds the decoded ternary
        # values, and `_kernel_scale` holds the scale that was saved alongside them.
        # Recomputing the scale from the ternary values would, in general, give a different
        # value than the one derived from the original kernel (assuming the scale depends on
        # the kernel's magnitudes; TODO confirm against the backend). So reuse both as they are.
        w_scale = self._kernel_scale
        w_quantized = ops.convert_to_numpy(self._kernel)

    # Encode the ternary array compactly
    shape, encoded = encode_ternary_array(w_quantized)

    # Then store the variables
    store["kernel_encoded"] = np.frombuffer(encoded, dtype="uint8")
    store["kernel_shape"] = shape
    store["kernel_scale"] = w_scale
    if self.use_bias:
        store["bias"] = self._bias
[docs]
def load_own_variables(self, store: Dict):
    """
    Loads the state of the layer.

    Reads the entries written by ``save_own_variables``: ``kernel_encoded`` (the packed
    ternary kernel bytes), ``kernel_shape``, ``kernel_scale`` and, if the layer uses a bias,
    ``bias``.

    Args:
        store: Dictionary from which the state of the model will be loaded.

    Raises:
        ValueError: If the layer is missing variables when loading from a file.
    """
    # Get the variables from the store first
    try:
        encoded = store["kernel_encoded"][()].tobytes()
        shape = store["kernel_shape"][()]
        w_scale = store["kernel_scale"][()]
        bias = store["bias"][()] if self.use_bias else None
    except (KeyError, ValueError) as e:  # pragma: no cover
        # Dict-like stores signal a missing entry with `KeyError`. Report every failure to read
        # the expected entries through the documented `ValueError`.
        raise ValueError("DenseMML layer missing values when loading from file") from e

    # Recover the weights in place so that the tracked variables (and hence `layer.weights`)
    # reflect the loaded values. The kernel now holds ternary values; its scale is kept
    # separately so that later quantization/saving can reuse it.
    self._kernel.assign(decode_ternary_array(shape, encoded))
    self._kernel_scale = w_scale
    if bias is not None:
        self._bias.assign(bias)