Understanding the Call Syntax
This guide will help you understand the type annotations used in the call signatures of the layers.
Why?
We use the jaxtyping package to write type annotations for arrays. These annotations specify the data type (for example, floating point) and the shape of the arrays (actually tensors) that the layers accept and return.
How To Interpret
Consider the following call signature.
- SomeLayer.call(self, inputs)
  Calling method of the layer.
  - Parameters:
    inputs (Float[np.ndarray, 'batch_size *dims last_dim']) – Inputs into the layer.
  - Returns:
    Float[np.ndarray, 'batch_size *dims units'] – Transformed inputs.
Let us unpack what the annotation Float[np.ndarray, 'batch_size *dims last_dim'] on inputs means.
In this annotation:
- Float means that the array (tensor) must contain floating-point values of any precision, such as float32 or float64.
- np.ndarray means that the value is a NumPy array.
- 'batch_size *dims last_dim' specifies the shape of the array.

The shape is a string of space-separated symbols, such as 'a b c d'. Each symbol names one axis. A symbol prefixed with * (like *dims above) is variadic: it matches zero or more consecutive axes. Thus *dims matches any number of intermediate axes, including none.
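The matching rule can be sketched in a few lines of plain Python. match_shape is a hypothetical helper written for this guide and is not part of jaxtyping. It splits the shape string on spaces, lets a single *-prefixed symbol absorb whatever axes the fixed symbols leave over, and returns the resulting bindings, or None if the shape does not fit. It ignores dtypes and does not check repeated names.

```python
from __future__ import annotations


def match_shape(pattern: str, shape: tuple[int, ...]) -> dict[str, object] | None:
    """Hypothetical helper: bind each symbol in `pattern` to part of `shape`.

    Fixed symbols bind to one axis size; a `*name` symbol binds to a tuple
    of zero or more axis sizes. Returns None when the shape cannot match.
    """
    symbols = pattern.split()
    variadic = [i for i, s in enumerate(symbols) if s.startswith("*")]
    if len(variadic) > 1:
        raise ValueError("this sketch supports at most one '*' symbol")

    if not variadic:
        # Without a variadic symbol, every symbol consumes exactly one axis.
        if len(shape) != len(symbols):
            return None
        return dict(zip(symbols, shape))

    v = variadic[0]
    n_after = len(symbols) - v - 1  # fixed symbols after the '*' symbol
    if len(shape) < v + n_after:
        return None  # not enough axes for the fixed symbols
    end = len(shape) - n_after
    bindings: dict[str, object] = dict(zip(symbols[:v], shape[:v]))
    bindings[symbols[v][1:]] = tuple(shape[v:end])  # zero or more axes
    bindings.update(zip(symbols[v + 1:], shape[end:]))
    return bindings


print(match_shape("batch_size *dims last_dim", (32, 5, 7, 10)))
# {'batch_size': 32, 'dims': (5, 7), 'last_dim': 10}
print(match_shape("batch_size *dims last_dim", (32, 10)))
# {'batch_size': 32, 'dims': (), 'last_dim': 10}
print(match_shape("batch_size *dims last_dim", (10,)))
# None
```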
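So the shape 'batch_size *dims last_dim' is one that:
- takes an axis of any size, named batch_size, as the first axis;
- accepts any number (including zero) of intermediate axes, collectively named dims; and
- takes an axis of any size, named last_dim, as the last axis.

An array with this shape therefore needs at least two axes.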
Now consider the return type Float[np.ndarray, 'batch_size *dims units']:
- Float means that an array of floats will be returned;
- np.ndarray indicates that the returned value is a NumPy array; and
- 'batch_size *dims units' indicates that the output has the same shape as the input array, except that the last axis has size units instead of last_dim. Because batch_size and *dims reuse the names from the input annotation, they denote the same sizes. jaxtyping enforces this consistency when runtime checking is enabled.
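The relationship between the input and output annotations amounts to a simple shape rule. The helpers below are hypothetical illustrations written for this guide and are not jaxtyping or library functions. expected_output_shape derives the output shape implied by 'batch_size *dims units' from an input shape that matches 'batch_size *dims last_dim'. check_call_shapes then compares an actual output shape against that expected shape.

```python
from __future__ import annotations


def expected_output_shape(input_shape: tuple[int, ...], units: int) -> tuple[int, ...]:
    """Hypothetical helper: the shape implied by 'batch_size *dims units'."""
    if len(input_shape) < 2:
        # 'batch_size *dims last_dim' needs at least two axes.
        raise ValueError(f"input shape {input_shape} has fewer than 2 axes")
    # Keep batch_size and *dims; replace last_dim with units.
    return tuple(input_shape[:-1]) + (units,)


def check_call_shapes(input_shape: tuple[int, ...],
                      output_shape: tuple[int, ...],
                      units: int) -> None:
    """Hypothetical helper: raise if output_shape breaks the shape contract."""
    expected = expected_output_shape(input_shape, units)
    if tuple(output_shape) != expected:
        raise ValueError(f"expected output shape {expected}, got {tuple(output_shape)}")


print(expected_output_shape((32, 10), units=4))        # (32, 4)
print(expected_output_shape((32, 5, 7, 10), units=4))  # (32, 5, 7, 4)
check_call_shapes((32, 5, 7, 10), (32, 5, 7, 4), units=4)  # passes silently
```

Unlike jaxtyping, this sketch compares only shapes and ignores dtypes.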