
Introduction

This example looks at the Kaggle Credit Card Fraud Detection dataset to demonstrate how to train a classification model on data with highly imbalanced classes. You can download the data by clicking “Download” at the link, or if you’re set up with a Kaggle API key at "~/.kaggle/kaggle.json", you can run the following:

reticulate::py_install("kaggle", pip = TRUE)
reticulate::py_available(TRUE) # ensure 'kaggle' is on the PATH
system("kaggle datasets download -d mlg-ulb/creditcardfraud")
zip::unzip("creditcardfraud.zip", files = "creditcard.csv")

First, load the data:

library(keras3)
library(magrittr) # provides the %>% and %<>% pipes used below
library(readr)
df <- read_csv("creditcard.csv", col_types = cols(
  Class = col_integer(),
  .default = col_double()
))
tibble::glimpse(df)
## Rows: 284,807
## Columns: 31
## $ Time   <dbl> 0, 0, 1, 1, 2, 2, 4, 7, 7, 9, 10, 10, 10, 11, 12, 12, 12, 1…
## $ V1     <dbl> -1.3598071, 1.1918571, -1.3583541, -0.9662717, -1.1582331, 
## $ V2     <dbl> -0.07278117, 0.26615071, -1.34016307, -0.18522601, 0.877736…
## $ V3     <dbl> 2.53634674, 0.16648011, 1.77320934, 1.79299334, 1.54871785,
## $ V4     <dbl> 1.37815522, 0.44815408, 0.37977959, -0.86329128, 0.40303393…
## $ V5     <dbl> -0.33832077, 0.06001765, -0.50319813, -0.01030888, -0.40719…
## $ V6     <dbl> 0.46238778, -0.08236081, 1.80049938, 1.24720317, 0.09592146…
## $ V7     <dbl> 0.239598554, -0.078802983, 0.791460956, 0.237608940, 0.5929…
## $ V8     <dbl> 0.098697901, 0.085101655, 0.247675787, 0.377435875, -0.2705…
## $ V9     <dbl> 0.3637870, -0.2554251, -1.5146543, -1.3870241, 0.8177393, -…
## $ V10    <dbl> 0.09079417, -0.16697441, 0.20764287, -0.05495192, 0.7530744…
## $ V11    <dbl> -0.55159953, 1.61272666, 0.62450146, -0.22648726, -0.822842…
## $ V12    <dbl> -0.61780086, 1.06523531, 0.06608369, 0.17822823, 0.53819555…
## $ V13    <dbl> -0.99138985, 0.48909502, 0.71729273, 0.50775687, 1.34585159…
## $ V14    <dbl> -0.31116935, -0.14377230, -0.16594592, -0.28792375, -1.1196…
## $ V15    <dbl> 1.468176972, 0.635558093, 2.345864949, -0.631418118, 0.1751…
## $ V16    <dbl> -0.47040053, 0.46391704, -2.89008319, -1.05964725, -0.45144…
## $ V17    <dbl> 0.207971242, -0.114804663, 1.109969379, -0.684092786, -0.23…
## $ V18    <dbl> 0.02579058, -0.18336127, -0.12135931, 1.96577500, -0.038194…
## $ V19    <dbl> 0.40399296, -0.14578304, -2.26185710, -1.23262197, 0.803486…
## $ V20    <dbl> 0.25141210, -0.06908314, 0.52497973, -0.20803778, 0.4085423…
## $ V21    <dbl> -0.018306778, -0.225775248, 0.247998153, -0.108300452, -0.0…
## $ V22    <dbl> 0.277837576, -0.638671953, 0.771679402, 0.005273597, 0.7982…
## $ V23    <dbl> -0.110473910, 0.101288021, 0.909412262, -0.190320519, -0.13…
## $ V24    <dbl> 0.06692807, -0.33984648, -0.68928096, -1.17557533, 0.141266…
## $ V25    <dbl> 0.12853936, 0.16717040, -0.32764183, 0.64737603, -0.2060095…
## $ V26    <dbl> -0.18911484, 0.12589453, -0.13909657, -0.22192884, 0.502292…
## $ V27    <dbl> 0.133558377, -0.008983099, -0.055352794, 0.062722849, 0.219…
## $ V28    <dbl> -0.021053053, 0.014724169, -0.059751841, 0.061457629, 0.215…
## $ Amount <dbl> 149.62, 2.69, 378.66, 123.50, 69.99, 3.67, 4.99, 40.80, 93.…
## $ Class  <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

Prepare a validation set

val_idx <- nrow(df) %>% sample.int(., round(. * 0.2))
val_df <- df[val_idx, ]
train_df <- df[-val_idx, ]

cat("Number of training samples:", nrow(train_df), "\n")
## Number of training samples: 227846
cat("Number of validation samples:", nrow(val_df), "\n")
## Number of validation samples: 56961
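
Note that the split is random, so the exact counts above will vary from run to run. For a reproducible split, fix the RNG seed before sampling (a minimal sketch; the seed value is arbitrary):

set.seed(42) # arbitrary seed, for reproducibility only
val_idx <- nrow(df) %>% sample.int(., round(. * 0.2))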

Analyze class imbalance in the targets

counts <- table(train_df$Class)
counts
##
##      0      1
## 227451    395
cat(sprintf("Number of positive samples in training data: %i (%.2f%% of total)",
            counts["1"], 100 * counts["1"] / sum(counts)))
## Number of positive samples in training data: 395 (0.17% of total)
weight_for_0 <- 1 / counts["0"]
weight_for_1 <- 1 / counts["1"]
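
The absolute weights are tiny; what matters is their ratio. From the counts above, 227451 / 395 ≈ 576, so each fraudulent sample will contribute roughly 576 times as much to the loss as a legitimate one:

weight_for_1 / weight_for_0 # ~576, from the class counts above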

Normalize the data using training set statistics

feature_names <- colnames(train_df) %>% setdiff("Class")

train_features <- as.matrix(train_df[feature_names])
train_targets <- as.matrix(train_df$Class)

val_features <- as.matrix(val_df[feature_names])
val_targets <- as.matrix(val_df$Class)

train_features %<>% scale()
val_features %<>% scale(center = attr(train_features, "scaled:center"),
                        scale = attr(train_features, "scaled:scale"))
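
As an optional sanity check, the scaled training features should now have column means near 0 and standard deviations near 1:

range(colMeans(train_features))     # all approximately 0
range(apply(train_features, 2, sd)) # all approximately 1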

Build a binary classification model

model <-
  keras_model_sequential(input_shape = ncol(train_features)) |>
  layer_dense(256, activation = "relu") |>
  layer_dense(256, activation = "relu") |>
  layer_dropout(0.3) |>
  layer_dense(256, activation = "relu") |>
  layer_dropout(0.3) |>
  layer_dense(1, activation = "sigmoid")

model
## Model: "sequential"
## ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
## ┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
## ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
## │ dense (Dense)                   │ (None, 256)            │         7,936 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense_1 (Dense)                 │ (None, 256)            │        65,792 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dropout (Dropout)               │ (None, 256)            │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense_2 (Dense)                 │ (None, 256)            │        65,792 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dropout_1 (Dropout)             │ (None, 256)            │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense_3 (Dense)                 │ (None, 1)              │           257 │
## └─────────────────────────────────┴────────────────────────┴───────────────┘
##  Total params: 139,777 (546.00 KB)
##  Trainable params: 139,777 (546.00 KB)
##  Non-trainable params: 0 (0.00 B)
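
The parameter counts can be verified by hand: each dense layer has inputs * units weights plus units biases, and there are 30 input features (31 columns minus Class):

(30 * 256 + 256) +   # dense:   7,936
(256 * 256 + 256) +  # dense_1: 65,792
(256 * 256 + 256) +  # dense_2: 65,792
(256 * 1 + 1)        # dense_3: 257; total 139,777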

Train the model with the class_weight argument

metrics <- list(
  metric_false_negatives(name = "fn"),
  metric_false_positives(name = "fp"),
  metric_true_negatives(name = "tn"),
  metric_true_positives(name = "tp"),
  metric_precision(name = "precision"),
  metric_recall(name = "recall")
)
model |> compile(
  optimizer = optimizer_adam(1e-2),
  loss = "binary_crossentropy",
  metrics = metrics
)
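
Accuracy is deliberately absent from this list: with about 99.8% of transactions legitimate, a model that never flags anything would already be 99.8% "accurate". If you want a single scalar to compare runs, area under the precision-recall curve is well suited to imbalanced data; a minimal sketch (this would need to be appended to metrics before compile()):

metrics <- c(metrics, metric_auc(name = "pr_auc", curve = "PR"))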
callbacks <- list(
  callback_model_checkpoint("fraud_model_at_epoch_{epoch}.keras")
)

class_weight <- list("0" = weight_for_0,
                     "1" = weight_for_1)

model |> fit(
  train_features, train_targets,
  validation_data = list(val_features, val_targets),
  class_weight = class_weight,
  batch_size = 2048,
  epochs = 30,
  callbacks = callbacks,
  verbose = 2
)
## Epoch 1/30
## 112/112 - 5s - 41ms/step - fn: 36.0000 - fp: 30458.0000 - loss: 2.3118e-06 - precision: 0.0116 - recall: 0.9089 - tn: 196993.0000 - tp: 359.0000 - val_fn: 14.0000 - val_fp: 714.0000 - val_loss: 0.0906 - val_precision: 0.1041 - val_recall: 0.8557 - val_tn: 56150.0000 - val_tp: 83.0000
## Epoch 2/30
## 112/112 - 1s - 11ms/step - fn: 34.0000 - fp: 5698.0000 - loss: 1.4544e-06 - precision: 0.0596 - recall: 0.9139 - tn: 221753.0000 - tp: 361.0000 - val_fn: 6.0000 - val_fp: 2982.0000 - val_loss: 0.1611 - val_precision: 0.0296 - val_recall: 0.9381 - val_tn: 53882.0000 - val_tp: 91.0000
## Epoch 3/30
## 112/112 - 0s - 2ms/step - fn: 30.0000 - fp: 7494.0000 - loss: 1.1267e-06 - precision: 0.0464 - recall: 0.9241 - tn: 219957.0000 - tp: 365.0000 - val_fn: 9.0000 - val_fp: 2262.0000 - val_loss: 0.1346 - val_precision: 0.0374 - val_recall: 0.9072 - val_tn: 54602.0000 - val_tp: 88.0000
## Epoch 4/30
## 112/112 - 0s - 2ms/step - fn: 25.0000 - fp: 5976.0000 - loss: 9.9649e-07 - precision: 0.0583 - recall: 0.9367 - tn: 221475.0000 - tp: 370.0000 - val_fn: 9.0000 - val_fp: 1593.0000 - val_loss: 0.0781 - val_precision: 0.0523 - val_recall: 0.9072 - val_tn: 55271.0000 - val_tp: 88.0000
## Epoch 5/30
## 112/112 - 0s - 2ms/step - fn: 22.0000 - fp: 7713.0000 - loss: 9.5664e-07 - precision: 0.0461 - recall: 0.9443 - tn: 219738.0000 - tp: 373.0000 - val_fn: 10.0000 - val_fp: 613.0000 - val_loss: 0.0444 - val_precision: 0.1243 - val_recall: 0.8969 - val_tn: 56251.0000 - val_tp: 87.0000
## Epoch 6/30
## 112/112 - 0s - 2ms/step - fn: 18.0000 - fp: 8517.0000 - loss: 9.6051e-07 - precision: 0.0424 - recall: 0.9544 - tn: 218934.0000 - tp: 377.0000 - val_fn: 12.0000 - val_fp: 1033.0000 - val_loss: 0.0643 - val_precision: 0.0760 - val_recall: 0.8763 - val_tn: 55831.0000 - val_tp: 85.0000
## Epoch 7/30
## 112/112 - 0s - 2ms/step - fn: 17.0000 - fp: 6409.0000 - loss: 7.3375e-07 - precision: 0.0557 - recall: 0.9570 - tn: 221042.0000 - tp: 378.0000 - val_fn: 10.0000 - val_fp: 1549.0000 - val_loss: 0.0697 - val_precision: 0.0532 - val_recall: 0.8969 - val_tn: 55315.0000 - val_tp: 87.0000
## Epoch 8/30
## 112/112 - 0s - 2ms/step - fn: 14.0000 - fp: 6080.0000 - loss: 6.5853e-07 - precision: 0.0590 - recall: 0.9646 - tn: 221371.0000 - tp: 381.0000 - val_fn: 10.0000 - val_fp: 2677.0000 - val_loss: 0.1178 - val_precision: 0.0315 - val_recall: 0.8969 - val_tn: 54187.0000 - val_tp: 87.0000
## Epoch 9/30
## 112/112 - 0s - 2ms/step - fn: 11.0000 - fp: 6592.0000 - loss: 5.9467e-07 - precision: 0.0550 - recall: 0.9722 - tn: 220859.0000 - tp: 384.0000 - val_fn: 9.0000 - val_fp: 2094.0000 - val_loss: 0.0806 - val_precision: 0.0403 - val_recall: 0.9072 - val_tn: 54770.0000 - val_tp: 88.0000
## Epoch 10/30
## 112/112 - 0s - 2ms/step - fn: 7.0000 - fp: 6334.0000 - loss: 5.3129e-07 - precision: 0.0577 - recall: 0.9823 - tn: 221117.0000 - tp: 388.0000 - val_fn: 10.0000 - val_fp: 850.0000 - val_loss: 0.0396 - val_precision: 0.0928 - val_recall: 0.8969 - val_tn: 56014.0000 - val_tp: 87.0000
## Epoch 11/30
## 112/112 - 0s - 2ms/step - fn: 8.0000 - fp: 5922.0000 - loss: 5.0009e-07 - precision: 0.0613 - recall: 0.9797 - tn: 221529.0000 - tp: 387.0000 - val_fn: 12.0000 - val_fp: 1488.0000 - val_loss: 0.0506 - val_precision: 0.0540 - val_recall: 0.8763 - val_tn: 55376.0000 - val_tp: 85.0000
## Epoch 12/30
## 112/112 - 0s - 2ms/step - fn: 8.0000 - fp: 7243.0000 - loss: 5.7125e-07 - precision: 0.0507 - recall: 0.9797 - tn: 220208.0000 - tp: 387.0000 - val_fn: 11.0000 - val_fp: 1014.0000 - val_loss: 0.0414 - val_precision: 0.0782 - val_recall: 0.8866 - val_tn: 55850.0000 - val_tp: 86.0000
## Epoch 13/30
## 112/112 - 0s - 2ms/step - fn: 6.0000 - fp: 4141.0000 - loss: 3.7800e-07 - precision: 0.0859 - recall: 0.9848 - tn: 223310.0000 - tp: 389.0000 - val_fn: 8.0000 - val_fp: 2491.0000 - val_loss: 0.0986 - val_precision: 0.0345 - val_recall: 0.9175 - val_tn: 54373.0000 - val_tp: 89.0000
## Epoch 14/30
## 112/112 - 0s - 2ms/step - fn: 6.0000 - fp: 6118.0000 - loss: 5.4537e-07 - precision: 0.0598 - recall: 0.9848 - tn: 221333.0000 - tp: 389.0000 - val_fn: 9.0000 - val_fp: 2488.0000 - val_loss: 0.1019 - val_precision: 0.0342 - val_recall: 0.9072 - val_tn: 54376.0000 - val_tp: 88.0000
## Epoch 15/30
## 112/112 - 0s - 2ms/step - fn: 4.0000 - fp: 4646.0000 - loss: 4.0941e-07 - precision: 0.0776 - recall: 0.9899 - tn: 222805.0000 - tp: 391.0000 - val_fn: 8.0000 - val_fp: 2949.0000 - val_loss: 0.1263 - val_precision: 0.0293 - val_recall: 0.9175 - val_tn: 53915.0000 - val_tp: 89.0000
## Epoch 16/30
## 112/112 - 0s - 2ms/step - fn: 10.0000 - fp: 7978.0000 - loss: 8.1907e-07 - precision: 0.0460 - recall: 0.9747 - tn: 219473.0000 - tp: 385.0000 - val_fn: 8.0000 - val_fp: 3430.0000 - val_loss: 0.2915 - val_precision: 0.0253 - val_recall: 0.9175 - val_tn: 53434.0000 - val_tp: 89.0000
## Epoch 17/30
## 112/112 - 0s - 2ms/step - fn: 23.0000 - fp: 12673.0000 - loss: 3.1681e-06 - precision: 0.0285 - recall: 0.9418 - tn: 214778.0000 - tp: 372.0000 - val_fn: 12.0000 - val_fp: 3416.0000 - val_loss: 0.7861 - val_precision: 0.0243 - val_recall: 0.8763 - val_tn: 53448.0000 - val_tp: 85.0000
## Epoch 18/30
## 112/112 - 0s - 2ms/step - fn: 23.0000 - fp: 11244.0000 - loss: 2.3250e-06 - precision: 0.0320 - recall: 0.9418 - tn: 216207.0000 - tp: 372.0000 - val_fn: 11.0000 - val_fp: 1717.0000 - val_loss: 0.1895 - val_precision: 0.0477 - val_recall: 0.8866 - val_tn: 55147.0000 - val_tp: 86.0000
## Epoch 19/30
## 112/112 - 0s - 2ms/step - fn: 16.0000 - fp: 8824.0000 - loss: 1.3502e-06 - precision: 0.0412 - recall: 0.9595 - tn: 218627.0000 - tp: 379.0000 - val_fn: 6.0000 - val_fp: 5161.0000 - val_loss: 0.1947 - val_precision: 0.0173 - val_recall: 0.9381 - val_tn: 51703.0000 - val_tp: 91.0000
## Epoch 20/30
## 112/112 - 0s - 2ms/step - fn: 10.0000 - fp: 8621.0000 - loss: 1.0916e-06 - precision: 0.0427 - recall: 0.9747 - tn: 218830.0000 - tp: 385.0000 - val_fn: 12.0000 - val_fp: 1107.0000 - val_loss: 0.1877 - val_precision: 0.0713 - val_recall: 0.8763 - val_tn: 55757.0000 - val_tp: 85.0000
## Epoch 21/30
## 112/112 - 0s - 2ms/step - fn: 10.0000 - fp: 7096.0000 - loss: 1.0528e-06 - precision: 0.0515 - recall: 0.9747 - tn: 220355.0000 - tp: 385.0000 - val_fn: 11.0000 - val_fp: 1099.0000 - val_loss: 0.0539 - val_precision: 0.0726 - val_recall: 0.8866 - val_tn: 55765.0000 - val_tp: 86.0000
## Epoch 22/30
## 112/112 - 0s - 3ms/step - fn: 11.0000 - fp: 5263.0000 - loss: 5.9445e-07 - precision: 0.0680 - recall: 0.9722 - tn: 222188.0000 - tp: 384.0000 - val_fn: 10.0000 - val_fp: 1399.0000 - val_loss: 0.0632 - val_precision: 0.0585 - val_recall: 0.8969 - val_tn: 55465.0000 - val_tp: 87.0000
## Epoch 23/30
## 112/112 - 0s - 4ms/step - fn: 8.0000 - fp: 5412.0000 - loss: 6.1044e-07 - precision: 0.0667 - recall: 0.9797 - tn: 222039.0000 - tp: 387.0000 - val_fn: 12.0000 - val_fp: 873.0000 - val_loss: 0.0727 - val_precision: 0.0887 - val_recall: 0.8763 - val_tn: 55991.0000 - val_tp: 85.0000
## Epoch 24/30
## 112/112 - 0s - 3ms/step - fn: 5.0000 - fp: 3469.0000 - loss: 4.1501e-07 - precision: 0.1011 - recall: 0.9873 - tn: 223982.0000 - tp: 390.0000 - val_fn: 11.0000 - val_fp: 1369.0000 - val_loss: 0.0758 - val_precision: 0.0591 - val_recall: 0.8866 - val_tn: 55495.0000 - val_tp: 86.0000
## Epoch 25/30
## 112/112 - 0s - 2ms/step - fn: 7.0000 - fp: 4489.0000 - loss: 6.9404e-07 - precision: 0.0796 - recall: 0.9823 - tn: 222962.0000 - tp: 388.0000 - val_fn: 11.0000 - val_fp: 1580.0000 - val_loss: 0.1499 - val_precision: 0.0516 - val_recall: 0.8866 - val_tn: 55284.0000 - val_tp: 86.0000
## Epoch 26/30
## 112/112 - 0s - 2ms/step - fn: 4.0000 - fp: 3469.0000 - loss: 4.4795e-07 - precision: 0.1013 - recall: 0.9899 - tn: 223982.0000 - tp: 391.0000 - val_fn: 12.0000 - val_fp: 951.0000 - val_loss: 0.0502 - val_precision: 0.0820 - val_recall: 0.8763 - val_tn: 55913.0000 - val_tp: 85.0000
## Epoch 27/30
## 112/112 - 0s - 2ms/step - fn: 5.0000 - fp: 4854.0000 - loss: 7.1714e-07 - precision: 0.0744 - recall: 0.9873 - tn: 222597.0000 - tp: 390.0000 - val_fn: 13.0000 - val_fp: 1354.0000 - val_loss: 0.0818 - val_precision: 0.0584 - val_recall: 0.8660 - val_tn: 55510.0000 - val_tp: 84.0000
## Epoch 28/30
## 112/112 - 0s - 3ms/step - fn: 6.0000 - fp: 4904.0000 - loss: 4.4107e-07 - precision: 0.0735 - recall: 0.9848 - tn: 222547.0000 - tp: 389.0000 - val_fn: 12.0000 - val_fp: 651.0000 - val_loss: 0.0420 - val_precision: 0.1155 - val_recall: 0.8763 - val_tn: 56213.0000 - val_tp: 85.0000
## Epoch 29/30
## 112/112 - 0s - 3ms/step - fn: 4.0000 - fp: 3909.0000 - loss: 4.2965e-07 - precision: 0.0909 - recall: 0.9899 - tn: 223542.0000 - tp: 391.0000 - val_fn: 12.0000 - val_fp: 817.0000 - val_loss: 0.0405 - val_precision: 0.0942 - val_recall: 0.8763 - val_tn: 56047.0000 - val_tp: 85.0000
## Epoch 30/30
## 112/112 - 0s - 2ms/step - fn: 2.0000 - fp: 2985.0000 - loss: 2.8492e-07 - precision: 0.1163 - recall: 0.9949 - tn: 224466.0000 - tp: 393.0000 - val_fn: 12.0000 - val_fp: 1040.0000 - val_loss: 0.0593 - val_precision: 0.0756 - val_recall: 0.8763 - val_tn: 55824.0000 - val_tp: 85.0000
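Now compute hard predictions on the validation set by thresholding the predicted probabilities at the conventional 0.5 cutoff. The cutoff is itself tunable: lowering it catches more fraud at the price of more false alarms. A purely illustrative variant with a lower cutoff:

# Illustrative only: a 0.25 cutoff favors recall over precision
val_pred_aggressive <- model %>%
  predict(val_features) %>%
  { as.integer(. > 0.25) }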
val_pred <- model %>%
  predict(val_features) %>%
  { as.integer(. > 0.5) }
## 1781/1781 - 1s - 436us/step
pred_correct <- val_df$Class == val_pred
cat(sprintf("Validation accuracy: %.2f", mean(pred_correct)))
## Validation accuracy: 0.98
fraudulent <- val_df$Class == 1

n_fraudulent_detected <- sum(fraudulent & pred_correct)
n_fraudulent_missed <- sum(fraudulent & !pred_correct)
n_legitimate_flagged <- sum(!fraudulent & !pred_correct)
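
These counts can be reported directly:

cat(sprintf("Fraudulent transactions detected: %d\n", n_fraudulent_detected))
cat(sprintf("Fraudulent transactions missed: %d\n", n_fraudulent_missed))
cat(sprintf("Legitimate transactions flagged: %d\n", n_legitimate_flagged))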

Conclusions

At the end of training, out of 56,961 validation transactions, we are:

  • Correctly identifying 85 of them as fraudulent
  • Missing 12 fraudulent transactions
  • At the cost of incorrectly flagging 1,040 legitimate transactions

In the real world, one would put an even higher weight on class 1, reflecting that false negatives (missed fraud) are more costly than false positives (legitimate transactions flagged).
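
For example (the factor of 10 below is arbitrary, purely for illustration):

class_weight <- list("0" = weight_for_0,
                     "1" = 10 * weight_for_1) # extra penalty on missed fraud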

Next time your credit card gets declined in an online purchase – this is why.