Text Classification with Transformers
This notebook is largely inspired by the Keras code example Text classification with Transformer by Apoorv Nandan.
In this example, we will do text classification using Keras-MML’s transformer implementation.
Note
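We will use the JAX backend for faster execution of the code. The backend is selected through the KERAS_BACKEND environment variable, which must be set before Keras is imported. Feel free to ignore the cell below.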
import os
os.environ["KERAS_BACKEND"] = "jax"
Preparing the Data
The dataset we will use is the IMDB movie review dataset. It contains movie reviews from IMDB, each labeled as having positive (1) or negative (0) sentiment. Keras splits it into 25000 training reviews and 25000 test reviews, and we use the test split for validation.
The dataset is available for importing in Keras, where, for convenience, the reviews have already been preprocessed. Each preprocessed review is encoded as a list of word indices, where words are indexed by their overall frequency in the dataset, so a smaller index means a more frequent word. The index 0 is reserved for padding. With the default arguments of load_data, index 1 marks the start of a review, index 2 marks out-of-vocabulary words, and the frequency ranks are shifted up by 3 (index_from=3), so the most frequent word is encoded as 4.
For our purposes, we will consider only the top 20000 words. This will be our vocabulary size (VOCAB_SIZE).
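To make the encoding concrete, here is a minimal pure-Python sketch of frequency-rank encoding on a toy corpus. It assumes the conventions of keras.datasets.imdb.load_data with its default arguments (start_char=1, oov_char=2, index_from=3) as we understand them; the helper names build_word_index and encode are hypothetical and not part of Keras.

```python
from collections import Counter

# Assumed reserved indices, mirroring load_data's defaults as we understand them.
PAD, START, OOV = 0, 1, 2
INDEX_FROM = 3  # word ranks are shifted up by this offset


def build_word_index(corpus):
    """Hypothetical helper: map each word to its frequency rank (1 = most frequent)."""
    counts = Counter(word for text in corpus for word in text.split())
    # Sort by descending count; break ties alphabetically so the result is deterministic.
    ranked = sorted(counts, key=lambda w: (-counts[w], w))
    return {word: rank for rank, word in enumerate(ranked, start=1)}


def encode(text, word_index, num_words):
    """Hypothetical helper: encode text as START followed by shifted word ranks.

    Unknown words, and words whose shifted index is not below num_words,
    are replaced by the OOV index.
    """
    ids = [START]
    for word in text.split():
        rank = word_index.get(word)
        index = rank + INDEX_FROM if rank is not None else OOV
        ids.append(index if index < num_words else OOV)
    return ids


corpus = ["good movie", "bad movie", "good good plot"]
word_index = build_word_index(corpus)
print(word_index)  # {'good': 1, 'movie': 2, 'bad': 3, 'plot': 4}
print(encode("good plot twist", word_index, num_words=7))  # [1, 4, 2, 2]
```

With num_words=7, only indices below 7 survive, so "plot" (shifted index 7) becomes OOV just like the unseen word "twist". This mirrors how num_words=VOCAB_SIZE caps the vocabulary in the cell below.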
import keras
VOCAB_SIZE = 20000
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)
How many sequences did we load?
print(len(x_train), "training sequences")
print(len(x_val), "validation sequences")
25000 training sequences
25000 validation sequences
We will limit each sequence to a length of 200 tokens (MAX_LEN). With the default settings of pad_sequences, sequences longer than MAX_LEN are truncated from the start, keeping their last MAX_LEN tokens, while shorter sequences are padded with zeros at the start up to MAX_LEN.
MAX_LEN = 200
x_train = keras.utils.pad_sequences(x_train, maxlen=MAX_LEN)
x_val = keras.utils.pad_sequences(x_val, maxlen=MAX_LEN)
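The default behaviour of pad_sequences (padding="pre", truncating="pre", padding value 0) can be sketched in pure Python as follows. The function pad_pre is a hypothetical stand-in for illustration, not the Keras implementation.

```python
def pad_pre(sequences, maxlen, value=0):
    """Left-pad short sequences and left-truncate long ones to exactly maxlen."""
    result = []
    for seq in sequences:
        # Keep only the last maxlen tokens (truncation from the start).
        kept = list(seq)[max(len(seq) - maxlen, 0):]
        # Prepend padding so every row has length maxlen.
        result.append([value] * (maxlen - len(kept)) + kept)
    return result


print(pad_pre([[1, 2, 3, 4, 5], [6, 7]], maxlen=3))  # [[3, 4, 5], [0, 6, 7]]
```

Because truncation keeps the end of each review, long reviews lose their opening words rather than their closing ones; passing truncating="post" to pad_sequences would keep the beginning instead.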
Creating the Model
Keras-MML provides a TransformerBlockMML layer. It works similarly to the transformer architecture described in Attention Is All You Need: it takes one vector per time step of the input and outputs one vector per time step. Each output vector can draw on information from every position in the sequence, so the result is an embedding that should encode more information about the text as a whole.
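For intuition about the "one vector per time step" behaviour, here is a minimal pure-Python sketch of single-head scaled dot-product self-attention, the core operation of the architecture in Attention Is All You Need. It is the standard, matmul-based formulation with the learned query, key, and value projections omitted (so Q = K = V = the input); it is not Keras-MML's matmul-free implementation.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x):
    """Single-head scaled dot-product self-attention with identity projections.

    x is a list of T vectors of dimension d; the result is another list of
    T vectors of dimension d, one per input time step.
    """
    d = len(x[0])
    outputs = []
    for query in x:
        # Similarity of this query with every key, scaled by sqrt(d).
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in x]
        weights = softmax(scores)
        # Each output is a weighted average of all value vectors.
        outputs.append([sum(w * value[j] for w, value in zip(weights, x)) for j in range(d)])
    return outputs


sequence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = self_attention(sequence)
print(len(out), len(out[0]))  # 3 2
```

The output has as many vectors as the input has time steps, and each one mixes information from the whole sequence according to the attention weights.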
import keras_mml
We first specify three hyperparameters for the model.
EMBEDDING_DIM gives the dimensionality of the embedding vector for each token in the sequence.
NUM_HEADS gives the number of heads to use in the multi-head attention part of the transformer layer.
FFN_DIM gives the intermediate (i.e., hidden) layer size of the feed-forward network (FFN) in the transformer.
For this example, we choose small values.
EMBEDDING_DIM = 32
NUM_HEADS = 2
FFN_DIM = 32
The architecture of our model is as follows.
We first look up an embedding for each token in the sequence and add to it a learned embedding for the token's position, which gives the initial embedding.
This initial embedding will be fed into the transformer block layer. The output will be refined embeddings that should encode more information about the sentence as a whole.
Afterwards, we take the mean (i.e., average) across all time steps using the standard GlobalAveragePooling1D layer from the base Keras package. Finally, we use a small fully-connected network (a stack of dense layers) on top of it to classify the sentiment of the review.
We will add some dropout in the final fully-connected network to act as regularization and reduce overfitting.
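The non-transformer parts of this pipeline can be sketched in pure Python. The sketch assumes that TokenEmbedding(..., with_positions=True) sums a token-embedding table and a learned position-embedding table, as the TokenAndPositionEmbedding layer in the original Keras example does; the parameter count reported for it in the model summary, (20000 + 200) × 32 = 646,400, is consistent with this. The toy tables, weights, and helper names are made up for illustration, the transformer block is skipped, and dropout is omitted because it is only active during training.

```python
import math

# Hypothetical toy sizes: vocabulary of 4 tokens, sequences of length 3, embedding dim 2.
token_table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # one row per token id
position_table = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]  # one row per position


def embed(token_ids):
    """Token embedding plus position embedding, one vector per time step."""
    return [
        [t + p for t, p in zip(token_table[tok], position_table[pos])]
        for pos, tok in enumerate(token_ids)
    ]


def global_average_pool(vectors):
    """Mean over time steps, like GlobalAveragePooling1D without a mask."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def dense_softmax(x, weights, bias):
    """Dense layer followed by softmax; weights has shape (len(x), num_classes)."""
    logits = [sum(xi * w[c] for xi, w in zip(x, weights)) + b for c, b in enumerate(bias)]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


embedded = embed([1, 2, 3])  # approximately [[1.1, 0.1], [0.2, 1.2], [1.3, 1.3]]
pooled = global_average_pool(embedded)  # approximately [0.8667, 0.8667]
probs = dense_softmax(pooled, weights=[[1.0, -1.0], [1.0, -1.0]], bias=[0.0, 0.0])
print(probs)  # two class probabilities that sum to 1
```

Pooling collapses the (time steps, embedding dim) output into a single vector per review, which is why the layers after GlobalAveragePooling1D in the summary have 2D output shapes.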
model = keras.models.Sequential(
layers=[
keras.layers.Input(shape=(MAX_LEN,)),
keras_mml.layers.TokenEmbedding(MAX_LEN, VOCAB_SIZE, EMBEDDING_DIM, with_positions=True),
keras_mml.layers.TransformerBlockMML(EMBEDDING_DIM, FFN_DIM, NUM_HEADS),
keras.layers.GlobalAveragePooling1D(),
keras.layers.Dropout(0.1),
keras.layers.Dense(20, activation="relu"),
keras.layers.Dropout(0.1),
keras.layers.Dense(2, activation="softmax")
]
)
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ token_embedding                 │ (None, 200, 32)        │       646,400 │
│ (TokenEmbedding)                │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ transformer_block_mml           │ (None, 200, 32)        │        33,120 │
│ (TransformerBlockMML)           │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling1d        │ (None, 32)             │             0 │
│ (GlobalAveragePooling1D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_2 (Dropout)             │ (None, 32)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 20)             │           660 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_3 (Dropout)             │ (None, 20)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 2)              │            42 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 680,222 (2.59 MB)
Trainable params: 680,222 (2.59 MB)
Non-trainable params: 0 (0.00 B)
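The parameter counts in the summary can be checked by hand for every layer except the transformer block, whose internals we do not reproduce here. The arithmetic below assumes that the embedding layer holds a VOCAB_SIZE × EMBEDDING_DIM token table plus a MAX_LEN × EMBEDDING_DIM position table, and uses the fact that a Dense layer has inputs × units weights plus units biases. The 33,120 transformer parameters are simply copied from the summary, and the memory estimate assumes 4-byte float32 weights and 1 MB = 1024² bytes.

```python
VOCAB_SIZE, MAX_LEN, EMBEDDING_DIM = 20000, 200, 32


def dense_params(inputs, units):
    """A Dense layer has one weight per (input, unit) pair plus one bias per unit."""
    return inputs * units + units


embedding = (VOCAB_SIZE + MAX_LEN) * EMBEDDING_DIM  # assumed token table + position table
transformer = 33_120  # taken from the summary, not derived
head = dense_params(EMBEDDING_DIM, 20) + dense_params(20, 2)  # 660 + 42
total = embedding + transformer + head

print(embedding, head, total)  # 646400 702 680222
print(round(total * 4 / 1024**2, 2))  # 2.59, assuming float32 weights
```

All computed values match the summary, which supports the assumption that the position embeddings are learned rather than fixed.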
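We will train the model with the Adam optimizer to minimise the sparse categorical crossentropy, which is the categorical crossentropy computed directly from integer labels (0 or 1 here) instead of one-hot vectors. We also report the accuracy of the model as a metric to monitor.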
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
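To show what the loss and metric compute, here is a minimal pure-Python sketch of sparse categorical crossentropy and accuracy for a batch of integer labels and predicted class probabilities. The function names are hypothetical, and the sketch omits the numerical safeguards (such as clipping probabilities away from 0) that a production implementation needs.

```python
import math


def sparse_categorical_crossentropy(labels, probs):
    """Mean negative log-probability assigned to the true class.

    labels: integer class ids (here 0 = negative, 1 = positive).
    probs: per-example lists of class probabilities, e.g. softmax outputs.
    """
    return -sum(math.log(p[y]) for y, p in zip(labels, probs)) / len(labels)


def accuracy(labels, probs):
    """Fraction of examples whose most probable class equals the label."""
    hits = sum(max(range(len(p)), key=p.__getitem__) == y for y, p in zip(labels, probs))
    return hits / len(labels)


labels = [1, 0, 1]
probs = [[0.2, 0.8], [0.9, 0.1], [0.6, 0.4]]
print(round(sparse_categorical_crossentropy(labels, probs), 4))
print(accuracy(labels, probs))  # 0.6666666666666666
```

Categorical crossentropy gives the same value when the labels are one-hot encoded; the sparse variant simply indexes the probability of the true class directly, which is why it suits the integer labels returned by load_data.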
Let’s train the model!
model.fit(
x_train, y_train, batch_size=32, epochs=3, validation_data=(x_val, y_val)
)
Epoch 1/3
782/782 ━━━━━━━━━━━━━━━━━━━━ 31s 36ms/step - accuracy: 0.6802 - loss: 0.5444 - val_accuracy: 0.8592 - val_loss: 0.3278
Epoch 2/3
782/782 ━━━━━━━━━━━━━━━━━━━━ 24s 28ms/step - accuracy: 0.9318 - loss: 0.1858 - val_accuracy: 0.8576 - val_loss: 0.3708
Epoch 3/3
782/782 ━━━━━━━━━━━━━━━━━━━━ 22s 29ms/step - accuracy: 0.9707 - loss: 0.0848 - val_accuracy: 0.8521 - val_loss: 0.4232
<keras.src.callbacks.history.History at 0x7fdb94ddb3d0>
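Two quick checks on the training log above. The 782 steps per epoch follow from 25000 training examples in batches of 32, with the last batch only partially filled. The validation loss is lowest after the first epoch and rises afterwards while training accuracy keeps climbing, a typical sign of overfitting. The sketch below reproduces both observations from the logged numbers; in practice, the keras.callbacks.EarlyStopping callback (for example with monitor="val_loss" and restore_best_weights=True) can automate the second check.

```python
import math

# Values copied from the training log above.
num_train, batch_size = 25_000, 32
val_losses = [0.3278, 0.3708, 0.4232]

steps_per_epoch = math.ceil(num_train / batch_size)  # the last batch holds 8 examples
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__) + 1  # 1-based

print(steps_per_epoch, best_epoch)  # 782 1
```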
How well did the model do?
val_loss, val_acc = model.evaluate(x_val, y_val)
print(f"Validation loss: {val_loss:.5f}")
print(f"Validation accuracy: {val_acc * 100:.2f}%")
782/782 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.8514 - loss: 0.4266
Validation loss: 0.42323
Validation accuracy: 85.21%
Conclusion
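In this notebook, we demonstrated how to use Keras-MML’s TransformerBlockMML layer as a matmul-free replacement for the traditional transformer architecture. We used TransformerBlockMML in a text classification example, where it reached a validation accuracy of 85.21% after three epochs of training. The validation loss was lowest after the first epoch, which suggests that the model starts to overfit with further training.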