Credit Card Fraud Detection
This notebook is largely inspired by the Keras code example Imbalanced classification: credit card fraud detection by fchollet.
It is essential that credit card companies can detect fraudulent transactions so that customers are not charged for items that they did not buy. This example uses the Kaggle Credit Card Fraud Detection dataset to demonstrate how to train a classification model on data with highly imbalanced classes.
Note
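We will use the JAX backend for faster execution of the code. The KERAS_BACKEND environment variable must be set before keras is first imported. Feel free to skip the cell below if you prefer a different backend.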
import os
os.environ["KERAS_BACKEND"] = "jax"
Preparing the Data
Loading the Data
The dataset we will be using is the Kaggle Credit Card Fraud Detection dataset. To access it, you will need a Kaggle account.
The dataset contains transactions made with credit cards by European cardholders over two days in September 2013, with 492 frauds among 284,807 transactions. The dataset is highly imbalanced: fraudulent transactions account for only 0.172% of all transactions. Despite this class imbalance, we will try to create a model that detects fraud.
The dataset contains only numerical input variables, most of which are the result of a Principal Component Analysis (PCA) transformation. The original features are not given; the columns V1 to V28 are the principal components obtained with PCA. The only columns that were not transformed are Time and Amount.
Time is the number of seconds elapsed between each transaction and the first transaction in the dataset.
Amount is the transaction amount.
Our aim is to predict the Class label, where 1 reflects a fraudulent transaction and 0 otherwise.
The dataset is saved in the file called creditcard.csv in the folder data. We will first vectorize the data.
FILE_NAME = "data/creditcard.csv"
import numpy as np
all_features = []
all_targets = []
with open(FILE_NAME) as f:
    for i, line in enumerate(f):
        if i == 0:
            # Skip the first line, which is the header
            print("HEADER:", line.strip())
            continue
        # Get the fields of that row
        fields = line.strip().split(",")
        all_features.append([float(v.replace('"', "")) for v in fields[:-1]])
        all_targets.append([int(fields[-1].replace('"', ""))])
        # Print the first line as an example of what features we have
        if i == 1:
            print("EXAMPLE FEATURES:", all_features[-1])
features = np.array(all_features, dtype="float32")
targets = np.array(all_targets, dtype="uint8")
print("Shape of features:", features.shape)
print("Shape of targets: ", targets.shape)
HEADER: "Time","V1","V2","V3","V4","V5","V6","V7","V8","V9","V10","V11","V12","V13","V14","V15","V16","V17","V18","V19","V20","V21","V22","V23","V24","V25","V26","V27","V28","Amount","Class"
EXAMPLE FEATURES: [0.0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443, -0.338320769942518, 0.462387777762292, 0.239598554061257, 0.0986979012610507, 0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348, -0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478, 0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705, -0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731, 0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215, 149.62]
Shape of features: (284807, 30)
Shape of targets: (284807, 1)
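The parsing loop above can be tried without the Kaggle download by feeding it a tiny in-memory sample. The sketch below is only illustrative: the shortened header and the two data rows are made-up toy values (not taken from creditcard.csv), and `parse_rows` is a hypothetical helper name.

```python
import io

# Toy stand-in for creditcard.csv: a header plus two rows with made-up values.
# The real file has 30 feature columns; three are used here for brevity.
SAMPLE_CSV = '''"Time","V1","Amount","Class"
0,-1.5,149.62,"0"
1,0.25,2.69,"1"
'''


def parse_rows(file_obj):
    """Split each row into float features and an integer class label."""
    features, targets = [], []
    for i, line in enumerate(file_obj):
        if i == 0:
            continue  # skip the header row
        fields = line.strip().split(",")
        # Strip any surrounding double quotes before converting.
        features.append([float(v.replace('"', "")) for v in fields[:-1]])
        targets.append(int(fields[-1].replace('"', "")))
    return features, targets


toy_features, toy_targets = parse_rows(io.StringIO(SAMPLE_CSV))
print(toy_features)
print(toy_targets)
```

The last field of every row is treated as the label and everything before it as a feature, which mirrors the `fields[:-1]` / `fields[-1]` split used in the notebook.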
Preprocessing the Data
First, we will split the data into training and validation datasets. The proportion of data that goes into each of the datasets will be controlled by the VAL_SPLIT constant.
VAL_SPLIT = 0.2
num_val_samples = int(len(features) * VAL_SPLIT)
train_features = features[:-num_val_samples]
train_targets = targets[:-num_val_samples]
val_features = features[-num_val_samples:]
val_targets = targets[-num_val_samples:]
print("Number of training samples:", len(train_features))
print("Number of validation samples:", len(val_features))
Number of training samples: 227846
Number of validation samples: 56961
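The split above is purely positional: the last VAL_SPLIT fraction of the rows, in file order, becomes the validation set, and nothing is shuffled. A minimal sketch on a toy array (the ten integers are placeholders, not real transactions) shows how the two slices partition the data.

```python
import numpy as np

VAL_SPLIT = 0.2

# Ten toy rows standing in for the transactions, in file order.
toy = np.arange(10)

num_val = int(len(toy) * VAL_SPLIT)  # 2 rows go to validation
train_part = toy[:-num_val]          # everything except the last 2 rows
val_part = toy[-num_val:]            # the last 2 rows

print(train_part)  # [0 1 2 3 4 5 6 7]
print(val_part)    # [8 9]
```

One caveat with this slicing idiom: if `num_val` were 0, `toy[:-0]` would be an empty array rather than the full data, so the pattern assumes the split produces at least one validation row.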
Let’s now analyse how severe the class imbalance is in the training dataset.
counts = np.bincount(train_targets[:, 0])
print(
    f"Number of fraudulent samples in training data: {counts[1]} ({100 * float(counts[1]) / len(train_targets):.2f}% of total)"
)
Number of fraudulent samples in training data: 417 (0.18% of total)
We will assign the weight for the classes using the inverse of the counts present in the training dataset.
weight_for_0 = 1.0 / counts[0]
weight_for_1 = 1.0 / counts[1]
print("Weight for normal transactions: ", weight_for_0)
print("Weight for fraudulent transactions: ", weight_for_1)
Weight for normal transactions: 4.396976638863118e-06
Weight for fraudulent transactions: 0.002398081534772182
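To see what inverse-count weighting achieves, the short sketch below redoes the arithmetic with the counts printed above (227,846 training samples, 417 of them fraudulent). With these weights, each class carries the same total weight, so the rare class is not drowned out by the common one.

```python
# Class counts in the training split, taken from the output above:
# 227,846 training samples, of which 417 are fraudulent.
num_train = 227846
num_fraud = 417
num_normal = num_train - num_fraud

weight_for_0 = 1.0 / num_normal
weight_for_1 = 1.0 / num_fraud

# Total weight carried by each class (count * per-sample weight).
total_weight_0 = num_normal * weight_for_0
total_weight_1 = num_fraud * weight_for_1

# How many times more a single fraud sample counts than a normal one.
ratio = weight_for_1 / weight_for_0

print(num_normal)  # 227429
print(total_weight_0, total_weight_1)
print(ratio)
```

Both totals come out at 1 (up to floating-point rounding), and a single fraudulent sample is weighted roughly 545 times as heavily as a single normal one.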
Lastly, we will normalize the data using the statistics in the training data.
mean = np.mean(train_features, axis=0)
std = np.std(train_features, axis=0)
train_features = (train_features - mean) / std
val_features = (val_features - mean) / std
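The effect of this standardisation can be checked on synthetic data. The sketch below uses randomly generated toy arrays rather than the real features, and it adds a guard against zero standard deviation that is an assumption of this example, not something the notebook does.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 1,000 "training" rows and 200 "validation" rows, 3 features each,
# drawn with a non-zero mean and non-unit scale.
toy_train = rng.normal(loc=5.0, scale=3.0, size=(1000, 3)).astype("float32")
toy_val = rng.normal(loc=5.0, scale=3.0, size=(200, 3)).astype("float32")

# Statistics come from the training rows only, so nothing about the
# validation rows leaks into preprocessing.
mean = toy_train.mean(axis=0)
std = toy_train.std(axis=0)

# Assumption (not in the notebook): guard against constant columns,
# which would otherwise cause a division by zero.
std = np.where(std == 0, 1.0, std)

toy_train_norm = (toy_train - mean) / std
toy_val_norm = (toy_val - mean) / std

print(toy_train_norm.mean(axis=0))  # close to 0 for every column
print(toy_train_norm.std(axis=0))   # close to 1 for every column
print(toy_val_norm.mean(axis=0))    # near 0, but not exactly
```

The validation columns end up only approximately centred because they are scaled with the training statistics, which is the intended behaviour.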
Creating the Model
The model architecture that we will be using here is nothing special. It is a standard fully-connected network with a classification head at the end. We use DenseMML layers for the hidden layers, but use the standard Dense layer for the classification head. We add some Dropout layers to prevent the model from overfitting.
import keras
import keras_mml
model = keras.Sequential(
    [
        keras.Input(shape=train_features.shape[1:]),
        keras_mml.layers.DenseMML(256, activation="relu"),
        keras_mml.layers.DenseMML(256, activation="relu"),
        keras.layers.Dropout(0.3),
        keras_mml.layers.DenseMML(256, activation="relu"),
        keras.layers.Dropout(0.3),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.summary()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_mml (DenseMML)            │ (None, 256)            │         7,966 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_mml_1 (DenseMML)          │ (None, 256)            │        66,048 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 256)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_mml_2 (DenseMML)          │ (None, 256)            │        66,048 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 256)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 1)              │           257 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 140,319 (548.12 KB)
Trainable params: 140,319 (548.12 KB)
Non-trainable params: 0 (0.00 B)
We compile the model to minimise the binary_crossentropy loss using the Adam optimizer. For our metrics, we will monitor
the number of false negatives;
the number of false positives;
the number of true negatives;
the number of true positives;
the precision of the model, which is given by
\[\mathrm{Precision} = \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FP}}\]
where \(\mathrm{TP}\) is the number of true positives and \(\mathrm{FP}\) is the number of false positives; and
the recall of the model, which is given by
\[\mathrm{Recall} = \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FN}}\]
where \(\mathrm{TP}\) is the number of true positives and \(\mathrm{FN}\) is the number of false negatives.
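As a worked instance of the two formulas, the sketch below plugs in the confusion counts reported in the validation statistics at the end of this notebook (67 true positives, 580 false positives, 8 false negatives). The function names are just illustrative helpers, not part of Keras.

```python
def precision(tp, fp):
    """Fraction of flagged transactions that are actually fraudulent."""
    return tp / (tp + fp)


def recall(tp, fn):
    """Fraction of fraudulent transactions that the model flags."""
    return tp / (tp + fn)


# Confusion counts from the final validation statistics in this notebook.
tp, fp, fn = 67, 580, 8

print(f"Precision: {precision(tp, fp) * 100:.3f}%")  # Precision: 10.355%
print(f"Recall:    {recall(tp, fn) * 100:.3f}%")     # Recall:    89.333%
```

The low precision alongside high recall is typical when the positive class is weighted heavily: most frauds are caught, at the price of many false alarms.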
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[
        keras.metrics.FalseNegatives(name="fn"),
        keras.metrics.FalsePositives(name="fp"),
        keras.metrics.TrueNegatives(name="tn"),
        keras.metrics.TruePositives(name="tp"),
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
    ]
)
We will weight the classes based on the weight_for_0 and weight_for_1 calculated in the previous section.
class_weight = {0: weight_for_0, 1: weight_for_1}
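Conceptually, class weighting multiplies each sample's loss by the weight of its true class. The sketch below illustrates the idea with a plain-Python binary cross-entropy. It is not Keras's internal implementation, and the exact reduction Keras applies may differ; `weighted_bce` and the toy predictions are hypothetical.

```python
import math


def weighted_bce(y_true, y_pred, class_weight):
    """Mean of per-sample binary cross-entropy, scaled by class weight."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        per_sample = -(y * math.log(p) + (1 - y) * math.log(1 - p))
        total += class_weight[y] * per_sample
    return total / len(y_true)


# Toy labels and predicted fraud probabilities (made-up values).
y_true = [0, 0, 0, 1]
y_pred = [0.1, 0.2, 0.1, 0.3]

unweighted = weighted_bce(y_true, y_pred, {0: 1.0, 1: 1.0})
# Weights of the same magnitude as those computed earlier.
weighted = weighted_bce(y_true, y_pred, {0: 4.4e-06, 1: 2.4e-03})

print(unweighted)
print(weighted)
```

Because the weights are so small, the weighted loss is several orders of magnitude below the unweighted one. This is consistent with the training log, where the training loss is on the order of 1e-06 while the validation loss, which is not class-weighted, is far larger.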
With all these defined, we can start the model training!
model.fit(
    train_features,
    train_targets,
    batch_size=2048,
    epochs=30,
    validation_data=(val_features, val_targets),
    class_weight=class_weight,
)
Epoch 1/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - fn: 50.9911 - fp: 37535.1875 - loss: 5.3293e-06 - precision: 0.0041 - recall: 0.7579 - tn: 78943.3047 - tp: 161.3186 - val_fn: 8.0000 - val_fp: 7554.0000 - val_loss: 0.4601 - val_precision: 0.0088 - val_recall: 0.8933 - val_tn: 49332.0000 - val_tp: 67.0000
Epoch 2/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - fn: 30.9469 - fp: 9665.1680 - loss: 3.0677e-06 - precision: 0.0167 - recall: 0.8398 - tn: 106811.8438 - tp: 182.8407 - val_fn: 10.0000 - val_fp: 1416.0000 - val_loss: 0.2439 - val_precision: 0.0439 - val_recall: 0.8667 - val_tn: 55470.0000 - val_tp: 65.0000
Epoch 3/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 28.1947 - fp: 4317.9468 - loss: 2.2399e-06 - precision: 0.0414 - recall: 0.8734 - tn: 112158.9062 - tp: 185.7522 - val_fn: 11.0000 - val_fp: 1057.0000 - val_loss: 0.1491 - val_precision: 0.0571 - val_recall: 0.8533 - val_tn: 55829.0000 - val_tp: 64.0000
Epoch 4/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 21.2920 - fp: 3781.5222 - loss: 1.9188e-06 - precision: 0.0461 - recall: 0.8919 - tn: 112704.6797 - tp: 183.3009 - val_fn: 8.0000 - val_fp: 1897.0000 - val_loss: 0.1507 - val_precision: 0.0341 - val_recall: 0.8933 - val_tn: 54989.0000 - val_tp: 67.0000
Epoch 5/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 22.0973 - fp: 4467.4424 - loss: 1.7561e-06 - precision: 0.0405 - recall: 0.8873 - tn: 112006.0938 - tp: 195.1593 - val_fn: 10.0000 - val_fp: 1079.0000 - val_loss: 0.0974 - val_precision: 0.0568 - val_recall: 0.8667 - val_tn: 55807.0000 - val_tp: 65.0000
Epoch 6/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 18.5221 - fp: 3918.7522 - loss: 1.4756e-06 - precision: 0.0479 - recall: 0.9141 - tn: 112558.3984 - tp: 195.1239 - val_fn: 10.0000 - val_fp: 925.0000 - val_loss: 0.0729 - val_precision: 0.0657 - val_recall: 0.8667 - val_tn: 55961.0000 - val_tp: 65.0000
Epoch 7/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 22.1947 - fp: 3846.5132 - loss: 1.4187e-06 - precision: 0.0482 - recall: 0.8791 - tn: 112633.3359 - tp: 188.7522 - val_fn: 9.0000 - val_fp: 1148.0000 - val_loss: 0.0740 - val_precision: 0.0544 - val_recall: 0.8800 - val_tn: 55738.0000 - val_tp: 66.0000
Epoch 8/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 20.9204 - fp: 4184.3716 - loss: 1.4327e-06 - precision: 0.0428 - recall: 0.8981 - tn: 112302.2578 - tp: 183.2478 - val_fn: 9.0000 - val_fp: 755.0000 - val_loss: 0.0541 - val_precision: 0.0804 - val_recall: 0.8800 - val_tn: 56131.0000 - val_tp: 66.0000
Epoch 9/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 13.2743 - fp: 3506.1416 - loss: 1.1186e-06 - precision: 0.0584 - recall: 0.9442 - tn: 112967.1875 - tp: 204.1947 - val_fn: 9.0000 - val_fp: 1405.0000 - val_loss: 0.0907 - val_precision: 0.0449 - val_recall: 0.8800 - val_tn: 55481.0000 - val_tp: 66.0000
Epoch 10/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 14.0708 - fp: 4288.3452 - loss: 1.0693e-06 - precision: 0.0439 - recall: 0.9451 - tn: 112186.6172 - tp: 201.7611 - val_fn: 9.0000 - val_fp: 1396.0000 - val_loss: 0.0858 - val_precision: 0.0451 - val_recall: 0.8800 - val_tn: 55490.0000 - val_tp: 66.0000
Epoch 11/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 13.1062 - fp: 3734.9026 - loss: 1.0049e-06 - precision: 0.0506 - recall: 0.9414 - tn: 112742.6719 - tp: 200.1151 - val_fn: 9.0000 - val_fp: 1307.0000 - val_loss: 0.0808 - val_precision: 0.0481 - val_recall: 0.8800 - val_tn: 55579.0000 - val_tp: 66.0000
Epoch 12/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 17.5310 - fp: 4386.2212 - loss: 1.5879e-06 - precision: 0.0430 - recall: 0.9047 - tn: 112088.4453 - tp: 198.6018 - val_fn: 10.0000 - val_fp: 450.0000 - val_loss: 0.0415 - val_precision: 0.1262 - val_recall: 0.8667 - val_tn: 56436.0000 - val_tp: 65.0000
Epoch 13/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 13.4425 - fp: 3037.5930 - loss: 1.0518e-06 - precision: 0.0727 - recall: 0.9453 - tn: 113431.5078 - tp: 208.2566 - val_fn: 9.0000 - val_fp: 1172.0000 - val_loss: 0.0739 - val_precision: 0.0533 - val_recall: 0.8800 - val_tn: 55714.0000 - val_tp: 66.0000
Epoch 14/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 12.6106 - fp: 3594.8672 - loss: 1.0163e-06 - precision: 0.0513 - recall: 0.9439 - tn: 112885.8906 - tp: 197.4248 - val_fn: 8.0000 - val_fp: 1115.0000 - val_loss: 0.0720 - val_precision: 0.0567 - val_recall: 0.8933 - val_tn: 55771.0000 - val_tp: 67.0000
Epoch 15/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 13.7345 - fp: 4030.9292 - loss: 1.1193e-06 - precision: 0.0445 - recall: 0.9352 - tn: 112452.0938 - tp: 194.0354 - val_fn: 8.0000 - val_fp: 698.0000 - val_loss: 0.0547 - val_precision: 0.0876 - val_recall: 0.8933 - val_tn: 56188.0000 - val_tp: 67.0000
Epoch 16/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 10.4867 - fp: 2935.1062 - loss: 8.6874e-07 - precision: 0.0643 - recall: 0.9585 - tn: 113542.8281 - tp: 202.3717 - val_fn: 8.0000 - val_fp: 823.0000 - val_loss: 0.0573 - val_precision: 0.0753 - val_recall: 0.8933 - val_tn: 56063.0000 - val_tp: 67.0000
Epoch 17/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 13.3982 - fp: 3766.9824 - loss: 9.2654e-07 - precision: 0.0580 - recall: 0.9424 - tn: 112701.5391 - tp: 208.8761 - val_fn: 8.0000 - val_fp: 567.0000 - val_loss: 0.0432 - val_precision: 0.1057 - val_recall: 0.8933 - val_tn: 56319.0000 - val_tp: 67.0000
Epoch 18/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 9.5398 - fp: 2723.8142 - loss: 8.6890e-07 - precision: 0.0684 - recall: 0.9616 - tn: 113762.1797 - tp: 195.2655 - val_fn: 8.0000 - val_fp: 1143.0000 - val_loss: 0.0651 - val_precision: 0.0554 - val_recall: 0.8933 - val_tn: 55743.0000 - val_tp: 67.0000
Epoch 19/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 8.4248 - fp: 3825.5044 - loss: 8.9759e-07 - precision: 0.0469 - recall: 0.9639 - tn: 112655.5859 - tp: 201.2832 - val_fn: 8.0000 - val_fp: 1183.0000 - val_loss: 0.0702 - val_precision: 0.0536 - val_recall: 0.8933 - val_tn: 55703.0000 - val_tp: 67.0000
Epoch 20/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 11.3363 - fp: 3549.0354 - loss: 9.2118e-07 - precision: 0.0516 - recall: 0.9475 - tn: 112929.7109 - tp: 200.7168 - val_fn: 8.0000 - val_fp: 849.0000 - val_loss: 0.0567 - val_precision: 0.0731 - val_recall: 0.8933 - val_tn: 56037.0000 - val_tp: 67.0000
Epoch 21/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 10.0177 - fp: 3406.5310 - loss: 9.2069e-07 - precision: 0.0557 - recall: 0.9581 - tn: 113073.7891 - tp: 200.4602 - val_fn: 8.0000 - val_fp: 1092.0000 - val_loss: 0.0653 - val_precision: 0.0578 - val_recall: 0.8933 - val_tn: 55794.0000 - val_tp: 67.0000
Epoch 22/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 8.5575 - fp: 3299.6992 - loss: 7.7663e-07 - precision: 0.0545 - recall: 0.9581 - tn: 113181.1094 - tp: 201.4336 - val_fn: 8.0000 - val_fp: 1064.0000 - val_loss: 0.0595 - val_precision: 0.0592 - val_recall: 0.8933 - val_tn: 55822.0000 - val_tp: 67.0000
Epoch 23/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 8.6814 - fp: 3343.1504 - loss: 7.6968e-07 - precision: 0.0584 - recall: 0.9669 - tn: 113132.6406 - tp: 206.3274 - val_fn: 8.0000 - val_fp: 942.0000 - val_loss: 0.0532 - val_precision: 0.0664 - val_recall: 0.8933 - val_tn: 55944.0000 - val_tp: 67.0000
Epoch 24/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 9.3186 - fp: 3097.9734 - loss: 7.4687e-07 - precision: 0.0615 - recall: 0.9630 - tn: 113381.9375 - tp: 201.5664 - val_fn: 9.0000 - val_fp: 552.0000 - val_loss: 0.0369 - val_precision: 0.1068 - val_recall: 0.8800 - val_tn: 56334.0000 - val_tp: 66.0000
Epoch 25/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 7.2655 - fp: 2699.3186 - loss: 6.1440e-07 - precision: 0.0702 - recall: 0.9636 - tn: 113776.6250 - tp: 207.5841 - val_fn: 7.0000 - val_fp: 1497.0000 - val_loss: 0.0760 - val_precision: 0.0435 - val_recall: 0.9067 - val_tn: 55389.0000 - val_tp: 68.0000
Epoch 26/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 7.8673 - fp: 3725.1150 - loss: 6.9181e-07 - precision: 0.0501 - recall: 0.9624 - tn: 112744.1484 - tp: 213.6637 - val_fn: 8.0000 - val_fp: 835.0000 - val_loss: 0.0479 - val_precision: 0.0743 - val_recall: 0.8933 - val_tn: 56051.0000 - val_tp: 67.0000
Epoch 27/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 7.0619 - fp: 3058.2654 - loss: 6.9943e-07 - precision: 0.0615 - recall: 0.9623 - tn: 113417.6484 - tp: 207.8230 - val_fn: 9.0000 - val_fp: 541.0000 - val_loss: 0.0337 - val_precision: 0.1087 - val_recall: 0.8800 - val_tn: 56345.0000 - val_tp: 66.0000
Epoch 28/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 6.3274 - fp: 2678.2478 - loss: 6.3577e-07 - precision: 0.0732 - recall: 0.9725 - tn: 113797.7266 - tp: 208.4956 - val_fn: 8.0000 - val_fp: 653.0000 - val_loss: 0.0377 - val_precision: 0.0931 - val_recall: 0.8933 - val_tn: 56233.0000 - val_tp: 67.0000
Epoch 29/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 8.9115 - fp: 3394.3274 - loss: 7.7616e-07 - precision: 0.0519 - recall: 0.9585 - tn: 113092.3906 - tp: 195.1681 - val_fn: 10.0000 - val_fp: 408.0000 - val_loss: 0.0258 - val_precision: 0.1374 - val_recall: 0.8667 - val_tn: 56478.0000 - val_tp: 65.0000
Epoch 30/30
112/112 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - fn: 6.4867 - fp: 2708.5930 - loss: 5.8697e-07 - precision: 0.0771 - recall: 0.9755 - tn: 113766.3516 - tp: 209.3628 - val_fn: 8.0000 - val_fp: 580.0000 - val_loss: 0.0325 - val_precision: 0.1036 - val_recall: 0.8933 - val_tn: 56306.0000 - val_tp: 67.0000
<keras.src.callbacks.history.History at 0x7fa3f4536170>
With the model trained, how does it do on the validation dataset?
val_loss, val_fn, val_fp, val_tn, val_tp, val_precision, val_recall = model.evaluate(val_features, val_targets)
print("--- Validation Statistics ---")
print("Loss: ", val_loss)
print("False Negatives:", int(val_fn))
print("False Positives:", int(val_fp))
print("True Negatives: ", int(val_tn))
print("True Positives: ", int(val_tp))
print("Precision: ", f"{val_precision * 100:.3f}%")
print("Recall: ", f"{val_recall * 100:.3f}%")
1781/1781 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - fn: 3.1655 - fp: 298.5135 - loss: 0.0364 - precision: 0.1158 - recall: 0.9160 - tn: 28185.0996 - tp: 41.1689
--- Validation Statistics ---
Loss: 0.03248622640967369
False Negatives: 8
False Positives: 580
True Negatives: 56306
True Positives: 67
Precision: 10.355%
Recall: 89.333%
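The four counts above come from comparing each predicted probability against a decision threshold, which is 0.5 by default for these Keras metrics. The sketch below reproduces that counting with NumPy on made-up labels and probabilities; `confusion_counts` is a hypothetical helper, not a Keras function.

```python
import numpy as np


def confusion_counts(y_true, y_prob, threshold=0.5):
    """Count TP, FP, TN, FN for binary labels at a probability threshold."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) > threshold
    tp = int(np.sum(y_pred & y_true))
    fp = int(np.sum(y_pred & ~y_true))
    tn = int(np.sum(~y_pred & ~y_true))
    fn = int(np.sum(~y_pred & y_true))
    return tp, fp, tn, fn


# Toy labels and predicted probabilities (made-up values).
labels = [0, 0, 0, 0, 1, 1]
probs = [0.1, 0.6, 0.4, 0.55, 0.9, 0.3]

print(confusion_counts(labels, probs))                  # (1, 2, 2, 1)
print(confusion_counts(labels, probs, threshold=0.75))  # (1, 0, 4, 1)
```

With the trained model, the probabilities would presumably come from `model.predict(val_features)[:, 0]` and the labels from `val_targets[:, 0]`. In this toy case, raising the threshold removes both false positives without losing the one detected fraud; on real data the trade-off depends on how the predicted probabilities are distributed.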
Conclusion
At the end of training, out of 56,961 validation transactions, we
correctly identify 67 of them as fraudulent;
miss 8 fraudulent transactions; and
incorrectly flag 580 legitimate transactions.
In practice, one would put an even higher weight on class 1 (i.e., the fraudulent class) to reflect that false negatives are more costly than false positives.
Regardless, this example shows how DenseMML can be used as a replacement for Dense layers in classification models, even if the classes present are imbalanced.
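One possible way to express the heavier weighting on class 1 mentioned above is to scale the inverse-count weight by a cost factor. The sketch below is only an illustration: `FRAUD_COST_FACTOR` is a hypothetical name, and its value of 5.0 is arbitrary rather than a recommendation.

```python
# Inverse-count weights as printed earlier in the notebook.
weight_for_0 = 4.396976638863118e-06
weight_for_1 = 0.002398081534772182

# Hypothetical multiplier expressing how much more costly a missed fraud
# is than a false alarm; the value 5.0 is an arbitrary illustration.
FRAUD_COST_FACTOR = 5.0

class_weight = {0: weight_for_0, 1: weight_for_1 * FRAUD_COST_FACTOR}
print(class_weight)
```

The resulting dictionary would be passed to `model.fit` through the `class_weight` argument, exactly as in the training cell. A suitable factor would have to be chosen from the actual costs of missed frauds and false alarms.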