Text Classification with Transformers

Jun 26, 2024


This notebook is largely inspired by the Keras code example "Text classification with Transformer" by Apoorv Nandan.

In this example, we will do text classification using Keras-MML’s transformer implementation.

Note

We will use the JAX backend for faster execution of the code. The KERAS_BACKEND environment variable must be set before keras is imported. Feel free to ignore the cell below.

import os
os.environ["KERAS_BACKEND"] = "jax"

Preparing the Data

The dataset we will use is the IMDB movie review dataset. The version shipped with Keras contains 25,000 training reviews and 25,000 test reviews from IMDB, where each review is labeled as having positive or negative sentiment.

The dataset is available for importing in Keras, where, for convenience, the reviews have already been preprocessed. Each preprocessed review is encoded as a list of word indices, where a word’s index reflects its frequency rank in the dataset. For example, a word encoded as 3 is the third most frequent word in the dataset. The index 0 is reserved for padding.
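To make the frequency-rank encoding concrete, here is a toy sketch in plain Python. The corpus and the helper names build_rank_index and encode are hypothetical and exist only for illustration; this is not how Keras builds the IMDB index, just the same idea on a tiny example.

```python
from collections import Counter

def build_rank_index(texts):
    """Map each word to its frequency rank (1 = most frequent).

    Index 0 is never assigned, so it remains free for padding.
    """
    counts = Counter(word for text in texts for word in text.lower().split())
    ranked = counts.most_common()  # sorted by descending count
    return {word: rank for rank, (word, _) in enumerate(ranked, start=1)}

def encode(text, index):
    """Encode a text as a list of word indices, skipping unknown words."""
    return [index[word] for word in text.lower().split() if word in index]

# Hypothetical toy corpus: "the" occurs 4 times, "movie" 3 times, "was" twice.
toy_corpus = ["the movie was good", "the movie was the best", "the movie"]
toy_index = build_rank_index(toy_corpus)
encoded = encode("the movie was", toy_index)
print(encoded)
```

Here the most frequent word gets index 1, the second most frequent gets index 2, and so on, mirroring the description above.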

For our purposes, we will consider only the top 20000 words. This will be our vocabulary size (VOCAB_SIZE).

import keras
VOCAB_SIZE = 20000
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)

How many sequences did we load?

print(len(x_train), "training sequences")
print(len(x_val), "validation sequences")
25000 training sequences
25000 validation sequences

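We will limit each sequence to a length of 200 (MAX_LEN). Sequences longer than MAX_LEN are truncated, and shorter ones are padded with zeros up to MAX_LEN. Note that, with its default settings, pad_sequences both truncates and pads at the start of the sequence, so the last MAX_LEN tokens of a long review are the ones that are kept.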

MAX_LEN = 200
x_train = keras.utils.pad_sequences(x_train, maxlen=MAX_LEN)
x_val = keras.utils.pad_sequences(x_val, maxlen=MAX_LEN)
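The effect of padding and truncation can be sketched in plain Python. The function pad_or_truncate below is a hypothetical stand-in that mimics the default "pre" padding and "pre" truncation of pad_sequences on a single sequence; it is not the Keras implementation.

```python
def pad_or_truncate(seq, maxlen, pad_value=0):
    """Return a list of exactly maxlen items (maxlen must be positive).

    Long sequences keep their LAST maxlen tokens; short sequences are
    padded on the LEFT with pad_value.
    """
    if maxlen <= 0:
        raise ValueError("maxlen must be positive")
    kept = list(seq)[-maxlen:]  # drop tokens from the start if too long
    return [pad_value] * (maxlen - len(kept)) + kept

short = pad_or_truncate([1, 2, 3], maxlen=5)
long = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], maxlen=5)
print(short)
print(long)
```

The short sequence gains two leading zeros, while the long sequence loses its first two tokens.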

Creating the Model

Keras-MML provides a TransformerBlockMML layer. It acts similarly to the transformer architecture described in Attention Is All You Need and outputs one vector per time step of the input. What results is an embedding that should encode more information about the text that went into the transformer.

import keras_mml

We first specify three hyperparameters for the model.

  • The EMBEDDING_DIM gives the dimensionality of the embedding vector for each token in the sequence.

  • The NUM_HEADS gives the number of heads to use in the multi-head attention part of the transformer layer.

  • The FFN_DIM gives the intermediate (i.e., hidden) layer size of the feed-forward network (FFN) in the transformer.

For this example, we choose small values.

EMBEDDING_DIM = 32
NUM_HEADS = 2
FFN_DIM = 32

The architecture of our model is as follows.

  • We first create embeddings for the tokens in the sequence. We add embeddings for the positions of the tokens to create an initial embedding.

  • This initial embedding will be fed into the transformer block layer. The output will be refined embeddings that should encode more information about the sentence as a whole.

  • Afterwards we take the mean (i.e., average) across all time steps using a standard GlobalAveragePooling1D layer available in the base Keras package.

  • Finally, we will use a fully-connected network (which is several dense layers) on top of it to classify the sentiment of the review.
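The first and third steps above can be sketched with plain Python lists. The tables and helper names (embed, mean_pool) are hypothetical toy stand-ins: in the real model the embedding tables are learned, and the internals of Keras-MML’s TokenEmbedding layer are assumed here rather than taken from its source.

```python
def embed(token_ids, token_table, position_table):
    """Add each token's vector to the vector for its position."""
    return [
        [t + p for t, p in zip(token_table[tok], position_table[pos])]
        for pos, tok in enumerate(token_ids)
    ]

def mean_pool(vectors):
    """Average the per-time-step vectors into a single vector."""
    steps = len(vectors)
    return [sum(column) / steps for column in zip(*vectors)]

# Toy tables: vocabulary of 4 tokens, 3 positions, embedding dimension 2.
token_table = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
position_table = [[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]]

embedded = embed([1, 3, 2], token_table, position_table)  # shape (3, 2)
pooled = mean_pool(embedded)                              # shape (2,)
print(embedded)
print(pooled)
```

Pooling collapses the time-step axis, which is why the model summary goes from a (None, 200, 32) output to a (None, 32) output at the pooling layer.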

We will add some dropout in the final fully-connected network to act as regularization and reduce overfitting.

model = keras.models.Sequential(
    layers=[
        keras.layers.Input(shape=(MAX_LEN,)),
        keras_mml.layers.TokenEmbedding(MAX_LEN, VOCAB_SIZE, EMBEDDING_DIM, with_positions=True),
        keras_mml.layers.TransformerBlockMML(EMBEDDING_DIM, FFN_DIM, NUM_HEADS),
        keras.layers.GlobalAveragePooling1D(),
        keras.layers.Dropout(0.1),
        keras.layers.Dense(20, activation="relu"),
        keras.layers.Dropout(0.1),
        keras.layers.Dense(2, activation="softmax")
    ]
)

model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ token_embedding                 │ (None, 200, 32)        │       646,400 │
│ (TokenEmbedding)                │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ transformer_block_mml           │ (None, 200, 32)        │        33,120 │
│ (TransformerBlockMML)           │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling1d        │ (None, 32)             │             0 │
│ (GlobalAveragePooling1D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_2 (Dropout)             │ (None, 32)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 20)             │           660 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_3 (Dropout)             │ (None, 20)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 2)              │            42 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 680,222 (2.59 MB)
 Trainable params: 680,222 (2.59 MB)
 Non-trainable params: 0 (0.00 B)
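As a sanity check, most of the parameter counts in the summary can be reproduced by hand. The sketch below assumes that the token embedding layer holds one table of VOCAB_SIZE vectors plus one table of MAX_LEN position vectors, which matches the reported count; the transformer block’s count is simply copied from the summary, since its internal breakdown is not derived here.

```python
VOCAB_SIZE, MAX_LEN, EMBEDDING_DIM = 20000, 200, 32

# Assumed layout: one vector per vocabulary entry plus one per position.
token_embedding = VOCAB_SIZE * EMBEDDING_DIM + MAX_LEN * EMBEDDING_DIM

# Copied from the summary above, not derived.
transformer_block = 33_120

# A dense layer has (inputs * units) weights plus one bias per unit.
dense = EMBEDDING_DIM * 20 + 20
dense_1 = 20 * 2 + 2

total = token_embedding + transformer_block + dense + dense_1
print(token_embedding, dense, dense_1, total)
```

The embedding tables account for roughly 95% of all parameters, which is typical for small models with a large vocabulary.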

We will train the model to minimise the sparse categorical crossentropy loss, which is the categorical crossentropy for labels given as integer class indices rather than one-hot vectors. We also report accuracy as a metric to monitor during training.

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
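For intuition, the loss and metric chosen above can be computed by hand on a tiny batch. This is a minimal sketch in plain Python with made-up probabilities; the function names are hypothetical, and the Keras versions additionally clip probabilities for numerical stability.

```python
import math

def sparse_categorical_crossentropy(labels, probs):
    """Mean of -log(probability assigned to the true class)."""
    losses = [-math.log(p[y]) for y, p in zip(labels, probs)]
    return sum(losses) / len(losses)

def accuracy(labels, probs):
    """Fraction of samples whose highest-probability class is the label."""
    predictions = [max(range(len(p)), key=p.__getitem__) for p in probs]
    hits = sum(int(pred == y) for pred, y in zip(predictions, labels))
    return hits / len(labels)

# Made-up softmax outputs for two reviews; integer labels (1 = positive).
labels = [1, 0]
probs = [[0.2, 0.8], [0.6, 0.4]]

loss = sparse_categorical_crossentropy(labels, probs)
acc = accuracy(labels, probs)
print(f"loss: {loss:.5f}, accuracy: {acc:.2f}")
```

Because the labels are plain integers, each one is used directly as an index into the predicted probability vector, so no one-hot encoding is needed.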

Let’s train the model!

model.fit(
    x_train, y_train, batch_size=32, epochs=3, validation_data=(x_val, y_val)
)
Epoch 1/3
782/782 ━━━━━━━━━━━━━━━━━━━━ 31s 36ms/step - accuracy: 0.6802 - loss: 0.5444 - val_accuracy: 0.8592 - val_loss: 0.3278
Epoch 2/3
782/782 ━━━━━━━━━━━━━━━━━━━━ 24s 28ms/step - accuracy: 0.9318 - loss: 0.1858 - val_accuracy: 0.8576 - val_loss: 0.3708
Epoch 3/3
782/782 ━━━━━━━━━━━━━━━━━━━━ 22s 29ms/step - accuracy: 0.9707 - loss: 0.0848 - val_accuracy: 0.8521 - val_loss: 0.4232
<keras.src.callbacks.history.History at 0x7fdb94ddb3d0>

How well did the model do?

val_loss, val_acc = model.evaluate(x_val, y_val)
print(f"Validation loss:     {val_loss:.5f}")
print(f"Validation accuracy: {val_acc * 100:.2f}%")
782/782 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.8514 - loss: 0.4266
Validation loss:     0.42323
Validation accuracy: 85.21%

Conclusion

In this notebook, we demonstrated how to use Keras-MML’s TransformerBlockMML layer as a matmul-free replacement for the traditional transformer architecture. We used TransformerBlockMML in a text classification example, where it reached about 85% validation accuracy after three epochs of training.