"""Implements a matmul-less Gated Linear Unit (GLU) layer."""

from typing import Any, Dict, Optional, Tuple

import keras
import numpy as np
from jaxtyping import Float
from keras import activations, ops

from keras_mml.layers.core import DenseMML

#: Set of activations that can be used with :py:class:`~GLUMML`.
PERMITTED_ACTIVATIONS = {"sigmoid", "linear", "relu", "gelu", "silu", "selu"}
@keras.saving.register_keras_serializable(package="keras_mml")
class GLUMML(keras.Layer):
    """
    General Gated Linear Unit (GLU) without matrix multiplications.

    This is a modified implementation of ``HGRNBitMLP`` from the
    `GitHub repository <https://github.com/ridgerchu/matmulfreellm>`_ of |MatMulFreeLLM|_ where,
    instead of just permitting the Swish activation, we permit other activations via the
    :py:attr:`~GLUMML.activation` attribute.

    See section 3.3.2 of the aforementioned paper for the notation used in the implementation of
    the code.

    Attributes:
        units: Dimensionality of the output space.
        hidden_ratio: Ratio adjusting the intermediate size.
        intermediate_size: Intermediate size. See the :py:func:`~GLUMML.__init__` method on how
            the intermediate size is determined.
        activation: GLU activation function.

    .. |MatMulFreeLLM| replace:: *Scalable MatMul-free Language Modeling*
    .. _MatMulFreeLLM: https://arxiv.org/pdf/2406.02528v5
    """

    def __init__(
        self,
        units: int,
        hidden_ratio: int = 4,
        intermediate_size: Optional[int] = None,
        activation: str = "sigmoid",
        **kwargs,
    ):
        """
        Initializes a new instance of the layer.

        Args:
            units: Dimensionality of the output space.
            hidden_ratio: Ratio adjusting the intermediate size. Ignored if an intermediate size
                is specified.
            intermediate_size: Intermediate size. If None, it is inferred when the layer is built:
                :math:`\\frac23 lr` is truncated to an integer and rounded **up** to a multiple of
                256 (never less than 256), where :math:`l` is the hidden shape given by the input
                into the layer and :math:`r` is the :py:attr:`~GLUMML.hidden_ratio`.
            activation: GLU activation function.
            **kwargs: Keyword arguments for :py:class:`keras.Layer`.

        Raises:
            ValueError: If the units provided is not a positive integer.
            ValueError: If the activation function specified is not in the
                :py:const:`~PERMITTED_ACTIVATIONS`.
        """

        # Validate before touching the Keras base class so a bad argument never leaves a
        # half-initialized layer behind
        if units <= 0:
            raise ValueError(f"Received an invalid value for units, expected a positive integer, got {units}")

        if activation not in PERMITTED_ACTIVATIONS:
            raise ValueError(
                f"GLU activation '{activation}' not allowed; permitted activations are {PERMITTED_ACTIVATIONS}"
            )

        super().__init__(**kwargs)

        # Main attributes
        self.units = units
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size  # May still be None; resolved in `build()`
        self.activation = activations.get(activation)

        # Hidden weights/layers
        self._gate_dense = None
        self._down_dense = None

    # Public methods
    def build(self, input_shape: Tuple[int, ...]):
        """
        Create layer weights.

        If no intermediate size was given at construction, it is inferred here from the last
        dimension of the input and stored on :py:attr:`~GLUMML.intermediate_size`.

        Args:
            input_shape: Shape of the input.

        Raises:
            ValueError: If the intermediate size has to be inferred but the last dimension of the
                input shape is unknown.
        """

        hidden_size = input_shape[-1]

        if self.intermediate_size is None:
            if hidden_size is None:
                # Without this check the arithmetic below fails with an opaque `TypeError`
                raise ValueError(
                    "The last dimension of the input must be defined to infer the intermediate size, "
                    f"got input shape {input_shape}"
                )

            # Truncate `2/3 * hidden_size * hidden_ratio` to an integer, then round it UP to a
            # multiple of 256
            intermediate_size = int(hidden_size * self.hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)

            # A truncated value of 0 would round "up" to 0 and give zero-width sub-layers, so keep
            # at least one block of 256
            self.intermediate_size = max(256, intermediate_size)

        self._gate_dense = DenseMML(self.intermediate_size * 2)  # We will use this layer for both $W_g$ and $W_u$
        self._gate_dense.build(input_shape)

        self._down_dense = DenseMML(self.units)  # We will use this layer for $W_d$
        self._down_dense.build((None, self.intermediate_size))

        self.built = True

    def call(
        self, inputs: Float[np.ndarray, "batch_size *dims last_dim"]
    ) -> Float[np.ndarray, "batch_size *dims units"]:
        """
        Calling method of the layer.

        Args:
            inputs: Inputs into the layer.

        Returns:
            Transformed inputs.
        """

        # One projection of width `2 * intermediate_size` yields both halves at once
        g_and_u = self._gate_dense(inputs)
        g, u = ops.split(g_and_u, 2, axis=-1)

        # Gate: the activated `g` half scales the `u` half elementwise
        p = ops.multiply(self.activation(g), u)

        # Project down to `units`
        d = self._down_dense(p)
        return d

    def compute_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """
        Computes the output shape of the layer.

        Args:
            input_shape: Shape of the input into the layer.

        Returns:
            Shape of the output, which is the input shape with its last dimension replaced by
            :py:attr:`~GLUMML.units`.
        """

        input_shape = list(input_shape)
        input_shape[-1] = self.units
        return tuple(input_shape)

    def get_config(self) -> Dict[str, Any]:
        """
        Gets the configuration for the layer.

        Returns:
            Layer configuration. Note that ``intermediate_size`` holds the inferred value if the
            layer has already been built without an explicit intermediate size.
        """

        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "hidden_ratio": self.hidden_ratio,
                "intermediate_size": self.intermediate_size,
                "activation": activations.serialize(self.activation),
            }
        )
        return config