Autoencoders

Jun 19, 2024

In this example we will create an autoencoder using the DenseMML layers available in Keras-MML.

The dataset we will use is the Fashion MNIST dataset. It is intended as a drop-in replacement for the standard MNIST dataset, and shares the same image size and structure of training and testing splits.

Fashion MNIST Dataset Cover

Important

We will be using some plotting libraries in this notebook. Run the command below to install them, then restart the kernel.

%pip install matplotlib~=3.9.0 seaborn~=0.13.2
Note: you may need to restart the kernel to use updated packages.

Setup

First, let’s prepare the imports.

import keras

Define constants relating to the data.

NUM_CLASSES = 10        # 10 distinct classes, 0 to 9
INPUT_SHAPE = (28, 28)  # 28 x 28 greyscale images

Load the Fashion MNIST data. Because an autoencoder is trained to reproduce its own input, we don't need the labels.

(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)

Scale the pixel values from integers in [0, 255] to floats in [0, 1].

x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
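Scaling matters here because the decoder's last layer uses a sigmoid activation, whose outputs lie between 0 and 1, so the reconstruction targets should be in the same range. The NumPy-only snippet below applies the same preprocessing to a tiny hand-made array (a stand-in for real images, not actual Fashion MNIST data) to show the resulting dtype and value range.

```python
import numpy as np

# Hypothetical stand-in for raw Fashion MNIST pixels: uint8 values in [0, 255].
raw = np.array([[[0, 128, 255]]], dtype=np.uint8)

# Same preprocessing as above. Casting to float32 first keeps the result in
# float32 (dividing the uint8 array directly would give float64).
scaled = raw.astype("float32") / 255

print(scaled.dtype)                # float32
print(scaled.min(), scaled.max())  # 0.0 1.0
```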

Check the shapes of the arrays that we have so far.

print(x_train.shape)
print(x_test.shape)
(60000, 28, 28)
(10000, 28, 28)

Let’s look at some of the data.

import matplotlib.pyplot as plt


def display_images(images):
    """Show the given images side by side in a single row, in greyscale."""
    n = len(images)
    plt.figure(figsize=(2 * n, 2))
    for i in range(n):
        ax = plt.subplot(1, n, i + 1)
        plt.imshow(images[i], cmap="gray")
        plt.title(f"Index {i}")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
display_images(x_train[:10])
[Image: the first 10 training images]
display_images(x_test[:10])
[Image: the first 10 test images]

Creating the Autoencoder

As with most other examples, the model architecture is simple: we replace the Dense layers with DenseMML layers, except for the final output layer (see the comment in the decoder code).

Of course, we need to import keras_mml first.

import keras_mml
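As its name suggests, Keras-MML aims to avoid matrix multiplications. We have not checked exactly how DenseMML quantizes its weights, so the sketch below rests on an assumption: it uses the absmean ternary scheme described for BitNet b1.58, in which every weight is divided by the mean absolute weight and then rounded to -1, 0 or +1. With weights like these, a layer's output can be computed with additions and subtractions alone, which the code checks against an ordinary matrix product. The function name ternary_quantize is hypothetical, and only NumPy is used.

```python
import numpy as np


def ternary_quantize(w, eps=1e-5):
    """Absmean ternary quantization in the style of BitNet b1.58.

    Illustrative assumption only: this is NOT taken from keras_mml, whose
    DenseMML layer may quantize weights (and activations) differently.
    """
    scale = np.mean(np.abs(w)) + eps           # one scaling factor for the whole matrix
    w_q = np.clip(np.round(w / scale), -1, 1)  # every weight becomes -1, 0 or +1
    return w_q, scale


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3)).astype("float32")  # full-precision weights: 4 inputs, 3 units
x = rng.normal(size=(2, 4)).astype("float32")  # a batch of 2 input vectors

w_q, scale = ternary_quantize(w)

# With ternary weights, each output is a sum of some inputs minus a sum of others:
# add x[:, i] where w_q[i, j] == 1 and subtract it where w_q[i, j] == -1.
y_add_sub = np.stack(
    [x[:, w_q[:, j] == 1].sum(axis=1) - x[:, w_q[:, j] == -1].sum(axis=1)
     for j in range(w_q.shape[1])],
    axis=1,
)
assert np.allclose(y_add_sub, x @ w_q)  # same result as the matrix product

# Multiplying by the scale gives an approximation of the full-precision x @ w.
y_approx = scale * y_add_sub
```

Because weights quantized this way are coarse, it seems sensible that the decoder keeps a regular Dense layer for the final pixel intensities.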

Now we define the encoder and decoder halves of the model.

def create_encoder():
    encoder = keras.Sequential(name="Encoder")
    encoder.add(keras.Input(INPUT_SHAPE, name="encoder_input"))
    encoder.add(keras.layers.Flatten())
    encoder.add(keras_mml.layers.DenseMML(512, activation="relu"))
    encoder.add(keras_mml.layers.DenseMML(256, activation="relu"))
    encoder.add(keras_mml.layers.DenseMML(128, activation="relu", name="encoder_output"))
    return encoder


def create_decoder():
    decoder = keras.Sequential(name="Decoder")
    decoder.add(keras.Input((128,), name="decoder_input"))
    decoder.add(keras_mml.layers.DenseMML(256, activation="relu"))
    decoder.add(keras_mml.layers.DenseMML(512, activation="relu"))
    decoder.add(keras.layers.Dense(784, activation="sigmoid"))  # Regular Dense layer: we need un-quantized outputs here
    decoder.add(keras.layers.Reshape(INPUT_SHAPE, name="decoder_output"))
    return decoder


def create_model(encoder, decoder):
    model = keras.Sequential(name="Autoencoder")
    model.add(keras.Input(INPUT_SHAPE, name="input"))
    model.add(encoder)
    model.add(decoder)

    model.compile(
        loss="mse",
        optimizer="adam",
        metrics=["mae"]
    )

    return model


model = create_model(create_encoder(), create_decoder())
model.summary(expand_nested=True)
Model: "Autoencoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ Encoder (Sequential)            │ (None, 128)            │       566,144 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ flatten (Flatten)          │ (None, 784)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ dense_mml (DenseMML)       │ (None, 512)            │       401,920 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ dense_mml_1 (DenseMML)     │ (None, 256)            │       131,328 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ encoder_output (DenseMML)  │ (None, 128)            │        32,896 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ Decoder (Sequential)            │ (None, 28, 28)         │       566,800 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ dense_mml_2 (DenseMML)     │ (None, 256)            │        33,024 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ dense_mml_3 (DenseMML)     │ (None, 512)            │       131,584 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ dense (Dense)              │ (None, 784)            │       402,192 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│    └ decoder_output (Reshape)   │ (None, 28, 28)         │             0 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,132,944 (4.32 MB)
 Trainable params: 1,132,944 (4.32 MB)
 Non-trainable params: 0 (0.00 B)
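The parameter counts in the summary follow the usual rule for a fully connected layer: one weight per input-output pair plus one bias per output unit, and the DenseMML layers report the same counts that equivalent Dense layers would. The plain-Python check below reproduces the totals; the size figure assumes 4 bytes (32-bit floats) per parameter and counts 1 MB as 1024 × 1024 bytes. The helper dense_params is a hypothetical name used only for this illustration.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


encoder_dims = [784, 512, 256, 128]  # flattened input, then the three encoder layers
decoder_dims = [128, 256, 512, 784]  # bottleneck, two DenseMML layers, final Dense layer

encoder_params = sum(dense_params(a, b) for a, b in zip(encoder_dims, encoder_dims[1:]))
decoder_params = sum(dense_params(a, b) for a, b in zip(decoder_dims, decoder_dims[1:]))
total = encoder_params + decoder_params

print(encoder_params, decoder_params, total)  # 566144 566800 1132944
print(f"{total * 4 / 1024**2:.2f} MB")        # 4.32 MB
```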

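With the model defined, we can train it. Keras holds out the last 10% of the training images (6,000 of them) for validation, and the EarlyStopping callback halts training once the validation loss has failed to improve by more than 1e-4 (min_delta) for 10 consecutive epochs (patience).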

model.fit(
    x_train, x_train,  # The targets are the inputs themselves
    validation_split=0.1,
    batch_size=256,
    epochs=200,
    callbacks=[
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, min_delta=1e-4, verbose=1),
    ]
)
Epoch 1/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 6s 20ms/step - loss: 0.1282 - mae: 0.3193 - val_loss: 0.0875 - val_mae: 0.2363
Epoch 2/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0871 - mae: 0.2351 - val_loss: 0.0872 - val_mae: 0.2342
Epoch 3/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0814 - mae: 0.2260 - val_loss: 0.0669 - val_mae: 0.1933
Epoch 4/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0662 - mae: 0.1914 - val_loss: 0.0642 - val_mae: 0.1858
Epoch 5/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0640 - mae: 0.1847 - val_loss: 0.0640 - val_mae: 0.1823
Epoch 6/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0635 - mae: 0.1827 - val_loss: 0.0632 - val_mae: 0.1821
Epoch 7/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0630 - mae: 0.1815 - val_loss: 0.0627 - val_mae: 0.1804
Epoch 8/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0623 - mae: 0.1799 - val_loss: 0.0611 - val_mae: 0.1785
Epoch 9/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0595 - mae: 0.1756 - val_loss: 0.0524 - val_mae: 0.1652
Epoch 10/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0509 - mae: 0.1615 - val_loss: 0.0492 - val_mae: 0.1554
Epoch 11/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 5s 18ms/step - loss: 0.0485 - mae: 0.1545 - val_loss: 0.0480 - val_mae: 0.1528
Epoch 12/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0477 - mae: 0.1513 - val_loss: 0.0484 - val_mae: 0.1521
Epoch 13/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0473 - mae: 0.1496 - val_loss: 0.0470 - val_mae: 0.1483
Epoch 14/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0465 - mae: 0.1474 - val_loss: 0.0465 - val_mae: 0.1466
Epoch 15/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0461 - mae: 0.1460 - val_loss: 0.0460 - val_mae: 0.1447
Epoch 16/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0458 - mae: 0.1450 - val_loss: 0.0454 - val_mae: 0.1441
Epoch 17/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0456 - mae: 0.1442 - val_loss: 0.0448 - val_mae: 0.1425
Epoch 18/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0449 - mae: 0.1425 - val_loss: 0.0440 - val_mae: 0.1413
Epoch 19/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0443 - mae: 0.1412 - val_loss: 0.0435 - val_mae: 0.1398
Epoch 20/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0434 - mae: 0.1391 - val_loss: 0.0433 - val_mae: 0.1389
Epoch 21/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0428 - mae: 0.1373 - val_loss: 0.0428 - val_mae: 0.1361
Epoch 22/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0418 - mae: 0.1351 - val_loss: 0.0422 - val_mae: 0.1342
Epoch 23/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0417 - mae: 0.1344 - val_loss: 0.0418 - val_mae: 0.1340
Epoch 24/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0411 - mae: 0.1330 - val_loss: 0.0423 - val_mae: 0.1357
Epoch 25/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0409 - mae: 0.1321 - val_loss: 0.0418 - val_mae: 0.1331
Epoch 26/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0407 - mae: 0.1315 - val_loss: 0.0412 - val_mae: 0.1323
Epoch 27/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0401 - mae: 0.1303 - val_loss: 0.0410 - val_mae: 0.1321
Epoch 28/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0399 - mae: 0.1295 - val_loss: 0.0397 - val_mae: 0.1301
Epoch 29/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0396 - mae: 0.1286 - val_loss: 0.0396 - val_mae: 0.1288
Epoch 30/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0392 - mae: 0.1277 - val_loss: 0.0409 - val_mae: 0.1309
Epoch 31/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0391 - mae: 0.1273 - val_loss: 0.0388 - val_mae: 0.1274
Epoch 32/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0385 - mae: 0.1259 - val_loss: 0.0383 - val_mae: 0.1248
Epoch 33/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0378 - mae: 0.1244 - val_loss: 0.0388 - val_mae: 0.1280
Epoch 34/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0378 - mae: 0.1246 - val_loss: 0.0380 - val_mae: 0.1250
Epoch 35/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0372 - mae: 0.1232 - val_loss: 0.0366 - val_mae: 0.1215
Epoch 36/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0362 - mae: 0.1210 - val_loss: 0.0359 - val_mae: 0.1197
Epoch 37/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0360 - mae: 0.1203 - val_loss: 0.0364 - val_mae: 0.1211
Epoch 38/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0356 - mae: 0.1193 - val_loss: 0.0352 - val_mae: 0.1172
Epoch 39/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0351 - mae: 0.1180 - val_loss: 0.0357 - val_mae: 0.1171
Epoch 40/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0347 - mae: 0.1170 - val_loss: 0.0347 - val_mae: 0.1168
Epoch 41/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0347 - mae: 0.1168 - val_loss: 0.0344 - val_mae: 0.1171
Epoch 42/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0343 - mae: 0.1157 - val_loss: 0.0348 - val_mae: 0.1155
Epoch 43/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0340 - mae: 0.1150 - val_loss: 0.0340 - val_mae: 0.1149
Epoch 44/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0338 - mae: 0.1146 - val_loss: 0.0344 - val_mae: 0.1137
Epoch 45/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0337 - mae: 0.1139 - val_loss: 0.0341 - val_mae: 0.1131
Epoch 46/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0335 - mae: 0.1138 - val_loss: 0.0354 - val_mae: 0.1153
Epoch 47/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0334 - mae: 0.1133 - val_loss: 0.0332 - val_mae: 0.1132
Epoch 48/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0331 - mae: 0.1127 - val_loss: 0.0334 - val_mae: 0.1128
Epoch 49/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0333 - mae: 0.1130 - val_loss: 0.0336 - val_mae: 0.1146
Epoch 50/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.0329 - mae: 0.1123 - val_loss: 0.0343 - val_mae: 0.1164
Epoch 51/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.0332 - mae: 0.1126 - val_loss: 0.0329 - val_mae: 0.1132
Epoch 52/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0327 - mae: 0.1116 - val_loss: 0.0327 - val_mae: 0.1111
Epoch 53/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.0325 - mae: 0.1112 - val_loss: 0.0336 - val_mae: 0.1142
Epoch 54/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0325 - mae: 0.1112 - val_loss: 0.0330 - val_mae: 0.1106
Epoch 55/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0326 - mae: 0.1112 - val_loss: 0.0333 - val_mae: 0.1133
Epoch 56/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0323 - mae: 0.1107 - val_loss: 0.0328 - val_mae: 0.1114
Epoch 57/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0326 - mae: 0.1112 - val_loss: 0.0333 - val_mae: 0.1140
Epoch 58/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0326 - mae: 0.1112 - val_loss: 0.0327 - val_mae: 0.1120
Epoch 59/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0323 - mae: 0.1103 - val_loss: 0.0323 - val_mae: 0.1095
Epoch 60/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0321 - mae: 0.1098 - val_loss: 0.0329 - val_mae: 0.1115
Epoch 61/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0323 - mae: 0.1104 - val_loss: 0.0320 - val_mae: 0.1097
Epoch 62/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0318 - mae: 0.1092 - val_loss: 0.0330 - val_mae: 0.1128
Epoch 63/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0315 - mae: 0.1086 - val_loss: 0.0315 - val_mae: 0.1090
Epoch 64/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0318 - mae: 0.1093 - val_loss: 0.0315 - val_mae: 0.1096
Epoch 65/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0312 - mae: 0.1080 - val_loss: 0.0315 - val_mae: 0.1094
Epoch 66/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0312 - mae: 0.1081 - val_loss: 0.0327 - val_mae: 0.1120
Epoch 67/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0310 - mae: 0.1075 - val_loss: 0.0324 - val_mae: 0.1093
Epoch 68/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0309 - mae: 0.1072 - val_loss: 0.0322 - val_mae: 0.1103
Epoch 69/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0307 - mae: 0.1069 - val_loss: 0.0302 - val_mae: 0.1063
Epoch 70/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0303 - mae: 0.1062 - val_loss: 0.0313 - val_mae: 0.1091
Epoch 71/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0302 - mae: 0.1060 - val_loss: 0.0315 - val_mae: 0.1094
Epoch 72/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0302 - mae: 0.1060 - val_loss: 0.0298 - val_mae: 0.1052
Epoch 73/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0296 - mae: 0.1049 - val_loss: 0.0302 - val_mae: 0.1066
Epoch 74/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0295 - mae: 0.1047 - val_loss: 0.0298 - val_mae: 0.1037
Epoch 75/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0292 - mae: 0.1039 - val_loss: 0.0291 - val_mae: 0.1053
Epoch 76/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0290 - mae: 0.1037 - val_loss: 0.0286 - val_mae: 0.1022
Epoch 77/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0286 - mae: 0.1025 - val_loss: 0.0339 - val_mae: 0.1124
Epoch 78/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0287 - mae: 0.1026 - val_loss: 0.0294 - val_mae: 0.1033
Epoch 79/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0285 - mae: 0.1025 - val_loss: 0.0293 - val_mae: 0.1047
Epoch 80/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0284 - mae: 0.1020 - val_loss: 0.0297 - val_mae: 0.1044
Epoch 81/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0284 - mae: 0.1018 - val_loss: 0.0279 - val_mae: 0.1009
Epoch 82/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0279 - mae: 0.1009 - val_loss: 0.0279 - val_mae: 0.1002
Epoch 83/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0280 - mae: 0.1011 - val_loss: 0.0278 - val_mae: 0.1007
Epoch 84/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - loss: 0.0278 - mae: 0.1006 - val_loss: 0.0277 - val_mae: 0.1010
Epoch 85/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0277 - mae: 0.1003 - val_loss: 0.0274 - val_mae: 0.0998
Epoch 86/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - loss: 0.0277 - mae: 0.1003 - val_loss: 0.0278 - val_mae: 0.1005
Epoch 87/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0278 - mae: 0.1006 - val_loss: 0.0301 - val_mae: 0.1060
Epoch 88/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0275 - mae: 0.0999 - val_loss: 0.0275 - val_mae: 0.1001
Epoch 89/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0276 - mae: 0.1001 - val_loss: 0.0282 - val_mae: 0.1020
Epoch 90/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0275 - mae: 0.0998 - val_loss: 0.0274 - val_mae: 0.0989
Epoch 91/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0275 - mae: 0.0998 - val_loss: 0.0281 - val_mae: 0.1019
Epoch 92/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0274 - mae: 0.0994 - val_loss: 0.0275 - val_mae: 0.1001
Epoch 93/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - loss: 0.0276 - mae: 0.0999 - val_loss: 0.0281 - val_mae: 0.0995
Epoch 94/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0272 - mae: 0.0991 - val_loss: 0.0270 - val_mae: 0.0987
Epoch 95/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0273 - mae: 0.0993 - val_loss: 0.0269 - val_mae: 0.0982
Epoch 96/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0274 - mae: 0.0996 - val_loss: 0.0275 - val_mae: 0.0989
Epoch 97/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0272 - mae: 0.0991 - val_loss: 0.0277 - val_mae: 0.1009
Epoch 98/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.0274 - mae: 0.0994 - val_loss: 0.0275 - val_mae: 0.0988
Epoch 99/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.0273 - mae: 0.0990 - val_loss: 0.0273 - val_mae: 0.0979
Epoch 100/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0269 - mae: 0.0981 - val_loss: 0.0278 - val_mae: 0.1005
Epoch 101/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0270 - mae: 0.0984 - val_loss: 0.0269 - val_mae: 0.0975
Epoch 102/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0270 - mae: 0.0983 - val_loss: 0.0274 - val_mae: 0.1011
Epoch 103/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - loss: 0.0274 - mae: 0.0991 - val_loss: 0.0273 - val_mae: 0.0986
Epoch 104/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.0271 - mae: 0.0985 - val_loss: 0.0276 - val_mae: 0.1000
Epoch 105/200
211/211 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.0270 - mae: 0.0983 - val_loss: 0.0272 - val_mae: 0.0988
Epoch 105: early stopping
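The EarlyStopping callback used above counts an epoch as an improvement only if val_loss falls more than min_delta below the best value seen so far, and it stops training after patience consecutive epochs without one. In the run above, the last qualifying improvement appears to be at epoch 95 (val_loss 0.0269), with training stopping ten epochs later at epoch 105. The function below is a simplified plain-Python restatement of that rule for illustration, not the Keras implementation; early_stopping_epoch is a hypothetical name.

```python
def early_stopping_epoch(val_losses, patience=10, min_delta=1e-4):
    """Return the 1-based epoch at which training would stop, or None if it never does.

    Simplified restatement of EarlyStopping(monitor="val_loss") in "min" mode;
    options such as baseline and start_from_epoch are ignored.
    """
    best = float("inf")
    wait = 0  # epochs since the last sufficient improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:  # improved by more than min_delta
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# 0.99995 is below 1.0, but not by more than min_delta, so the counter is not reset.
print(early_stopping_epoch([1.0, 0.99995, 0.99996], patience=2, min_delta=1e-4))  # 3
print(early_stopping_epoch([1.0, 0.9, 0.8], patience=2))  # None
```

Note that restore_best_weights defaults to False, so after stopping the model keeps the weights from the final epoch rather than from the best one; passing restore_best_weights=True to EarlyStopping would change that.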

How well did the model do?

test_mse, test_mae = model.evaluate(x_test, x_test)
print(f"Testing MSE: {test_mse:.5f}")
print(f"Testing MAE: {test_mae:.5f}")
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0268 - mae: 0.0980
Testing MSE: 0.02703
Testing MAE: 0.09843
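Both metrics average the per-pixel error over every pixel of every test image. Because the pixels were scaled to [0, 1], a test MAE of about 0.098 means a reconstructed pixel is off by roughly 0.1 on average, or about 25 of the original 255 grey levels. The small NumPy example below spells out the two formulas on a hand-made pair of 2 × 2 "images" (illustrative values, not model output).

```python
import numpy as np

# A target image and a reconstruction that is off by 0.5 in a single pixel.
y_true = np.array([[[0.0, 1.0], [1.0, 0.0]]])
y_pred = np.array([[[0.5, 1.0], [1.0, 0.0]]])

mse = np.mean((y_true - y_pred) ** 2)   # squared errors, averaged over all pixels
mae = np.mean(np.abs(y_true - y_pred))  # absolute errors, averaged over all pixels
print(mse, mae)  # 0.0625 0.125
```

MSE punishes a few large pixel errors more heavily than MAE does, which is why the two numbers are worth reading together.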

Experiments with the Trained Model

Let’s grab one of the testing samples.

testing_sample = x_test[1234]

What does it look like?

display_images([testing_sample])
[Image: test sample 1234]

Let’s pass it through the autoencoder and see its output.

transformed_testing_sample = model.predict(testing_sample.reshape((1, 28, 28)))[0]
display_images([transformed_testing_sample])
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 148ms/step
[Image: the autoencoder's reconstruction of test sample 1234]
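model.predict works on batches, so the single 28 × 28 image is reshaped to have a leading batch axis of size 1, and [0] then takes the only reconstruction back out of the returned batch. The NumPy-only snippet below shows a few equivalent ways to add that axis, using blank stand-in images rather than real test data.

```python
import numpy as np

sample = np.zeros((28, 28), dtype="float32")  # stand-in for a single image such as x_test[1234]

# Three equivalent ways to add the leading batch axis.
a = sample.reshape((1, 28, 28))
b = sample[np.newaxis, ...]
c = np.expand_dims(sample, axis=0)
print(a.shape, b.shape, c.shape)  # (1, 28, 28) (1, 28, 28) (1, 28, 28)

# Slicing a range of length 1 keeps the batch axis, so no reshape is needed at all.
batch = np.zeros((5, 28, 28), dtype="float32")  # stand-in for x_test
print(batch[2:3].shape)  # (1, 28, 28)
```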

Let’s run it on a whole range of samples at once.

testing_samples = x_test[123:123+10]
display_images(testing_samples)
[Image: test samples 123 to 132]
transformed_testing_samples = model.predict(testing_samples)  # already shaped (10, 28, 28)
display_images(transformed_testing_samples)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 148ms/step
[Image: the autoencoder's reconstructions of test samples 123 to 132]
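Beyond inspecting images by eye, one option is to compute the reconstruction error of each test image and look at the ones the autoencoder handles worst. The sketch below uses small hypothetical arrays in place of x_test and model.predict(x_test); per_image_mse is a hypothetical helper, not part of Keras or Keras-MML.

```python
import numpy as np


def per_image_mse(originals, reconstructions):
    """Mean squared error of each image, averaged over its own pixels."""
    return np.mean((originals - reconstructions) ** 2, axis=(1, 2))


# Hypothetical stand-ins for x_test and model.predict(x_test).
originals = np.stack([np.zeros((28, 28)), np.ones((28, 28)), np.full((28, 28), 0.5)])
reconstructions = np.stack([np.zeros((28, 28)), np.full((28, 28), 0.8), np.full((28, 28), 0.4)])

errors = per_image_mse(originals, reconstructions)
worst_first = np.argsort(errors)[::-1]  # image indices from largest to smallest error
print(worst_first)  # [1 2 0]
```

With the real arrays, passing x_test[worst_first[:10]] to display_images would show the ten hardest test images.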

Conclusion

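In this example, we have seen that DenseMML layers can stand in for Dense layers in a traditional autoencoder, with only the final output layer kept as a regular Dense layer. The model reaches a test MSE of about 0.027 and a test MAE of about 0.098 on pixel values scaled to [0, 1], which suggests that the 128-unit bottleneck retains much of the important information in each image.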