Structured data classification with FeatureSpace
Source: vignettes/examples/structured_data/structured_data_classification_with_feature_space.Rmd
Introduction
This example demonstrates how to do structured data classification (also known as tabular data classification), starting from a raw CSV file. Our data includes numerical features, integer categorical features, and string categorical features. We will use the utility layer_feature_space() to index, preprocess, and encode our features.
The code is adapted from the example Structured data classification from scratch. While the previous example managed its own low-level feature preprocessing and encoding with Keras preprocessing layers, in this example we delegate everything to layer_feature_space(), making the workflow extremely quick and easy.
The dataset
Our dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It’s a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has heart disease (binary classification).
Here’s the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical |
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
Setup
library(readr)
library(dplyr, warn.conflicts = FALSE)
library(keras3)
library(tensorflow, exclude = c("shape", "set_random_seed"))
library(tfdatasets, exclude = "shape")
conflicted::conflicts_prefer(
keras3::shape(),
keras3::set_random_seed(),
dplyr::filter(),
.quiet = TRUE
)
use_backend("tensorflow")
Preparing the data
Let’s download the data and load it into a data frame:
file_url <-
"http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
df <- read_csv(file_url, col_types = cols(
oldpeak = col_double(),
thal = col_character(),
.default = col_integer()
))
# the dataset has two malformed rows, filter them out
df <- df |> filter(!thal %in% c("1", "2"))
After filtering, the dataset includes 301 samples with 14 columns per sample (13 features, plus the target label):
glimpse(df)
## Rows: 301
## Columns: 14
## $ age <int> 63, 67, 67, 37, 41, 56, 62, 57, 63, 53, 57, 56, 56, 44, 5…
## $ sex <int> 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, …
## $ cp <int> 1, 4, 4, 3, 2, 2, 4, 4, 4, 4, 4, 2, 3, 2, 3, 3, 2, 4, 3, …
## $ trestbps <int> 145, 160, 120, 130, 130, 120, 140, 120, 130, 140, 140, 14…
## $ chol <int> 233, 286, 229, 250, 204, 236, 268, 354, 254, 203, 192, 29…
## $ fbs <int> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, …
## $ restecg <int> 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 0, 0, …
## $ thalach <int> 150, 108, 129, 187, 172, 178, 160, 163, 147, 155, 148, 15…
## $ exang <int> 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ oldpeak <dbl> 2.3, 1.5, 2.6, 3.5, 1.4, 0.8, 3.6, 0.6, 1.4, 3.1, 0.4, 1.…
## $ slope <int> 3, 2, 2, 3, 1, 1, 3, 1, 2, 3, 2, 2, 2, 1, 1, 1, 3, 1, 1, …
## $ ca <int> 0, 3, 2, 0, 0, 0, 2, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ thal <chr> "fixed", "normal", "reversible", "normal", "normal", "nor…
## $ target <int> 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
Here’s a preview of a few samples:
df
## # A tibble: 301 × 14
## age sex cp trestbps chol fbs restecg thalach exang oldpeak
## <int> <int> <int> <int> <int> <int> <int> <int> <int> <dbl>
## 1 63 1 1 145 233 1 2 150 0 2.3
## 2 67 1 4 160 286 0 2 108 1 1.5
## 3 67 1 4 120 229 0 2 129 1 2.6
## 4 37 1 3 130 250 0 0 187 0 3.5
## 5 41 0 2 130 204 0 2 172 0 1.4
## 6 56 1 2 120 236 0 0 178 0 0.8
## 7 62 0 4 140 268 0 2 160 0 3.6
## 8 57 0 4 120 354 0 0 163 1 0.6
## 9 63 1 4 130 254 0 2 147 0 1.4
## 10 53 1 4 140 203 1 2 155 1 3.1
## # ℹ 291 more rows
## # ℹ 4 more variables: slope <int>, ca <int>, thal <chr>, target <int>
The last column, “target”, indicates whether the patient has heart disease (1) or not (0).
Let’s split the data into a training and validation set:
val_idx <- nrow(df) %>% sample.int(., . * 0.2)
val_df <- df[val_idx, ]
train_df <- df[-val_idx, ]
cat(sprintf(
"Using %d samples for training and %d for validation",
nrow(train_df), nrow(val_df)
))
## Using 241 samples for training and 60 for validation
Let’s generate tf_dataset objects for each data frame:
dataframe_to_dataset <- function(df) {
labels <- df |> pull(target) |> as.integer()
inputs <- df |> select(-target) |> as.list()
ds <- tensor_slices_dataset(list(inputs, labels)) |>
dataset_shuffle(nrow(df))
ds
}
train_ds <- dataframe_to_dataset(train_df)
val_ds <- dataframe_to_dataset(val_df)
Each tf_dataset yields a tuple (input, target) where input is a dictionary (a named list) of features and target is the value 0 or 1:
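For example, we can pull a single element from train_ds and inspect it. A minimal sketch of such a chunk (using the same iter_next() / %<-% pattern that appears later in this example; the exact code here is an assumption) would produce output like the following:
# Grab one (input, target) pair from the (still unbatched) training dataset
c(x, y) %<-% iter_next(as_iterator(train_ds))
cat("Input: "); str(x)
cat("Target: "); str(y)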
## Input: List of 13
## $ age :<tf.Tensor: shape=(), dtype=int32, numpy=45>
## $ sex :<tf.Tensor: shape=(), dtype=int32, numpy=1>
## $ cp :<tf.Tensor: shape=(), dtype=int32, numpy=1>
## $ trestbps:<tf.Tensor: shape=(), dtype=int32, numpy=110>
## $ chol :<tf.Tensor: shape=(), dtype=int32, numpy=264>
## $ fbs :<tf.Tensor: shape=(), dtype=int32, numpy=0>
## $ restecg :<tf.Tensor: shape=(), dtype=int32, numpy=0>
## $ thalach :<tf.Tensor: shape=(), dtype=int32, numpy=132>
## $ exang :<tf.Tensor: shape=(), dtype=int32, numpy=0>
## $ oldpeak :<tf.Tensor: shape=(), dtype=float32, numpy=1.2>
## $ slope :<tf.Tensor: shape=(), dtype=int32, numpy=2>
## $ ca :<tf.Tensor: shape=(), dtype=int32, numpy=0>
## $ thal :<tf.Tensor: shape=(), dtype=string, numpy=b'reversible'>
## Target: <tf.Tensor: shape=(), dtype=int32, numpy=0>
Let’s batch the datasets:
train_ds <- train_ds |> dataset_batch(32)
val_ds <- val_ds |> dataset_batch(32)
Configuring a FeatureSpace
To configure how each feature should be preprocessed, we instantiate a layer_feature_space() and pass it a dictionary (a named list with unique names) that maps the name of each feature to a string describing its feature type.
We have a few “integer categorical” features such as "FBS", one “string categorical” feature ("thal"), and a few numerical features, which we’d like to normalize – except "age", which we’d like to discretize into a number of bins.
We also use the crosses argument to capture feature interactions for some categorical features – that is to say, to create additional features that represent value co-occurrences for these categorical features. You can compute feature crosses like this for arbitrary sets of categorical features – not just tuples of two features. Because the resulting co-occurrences are hashed into a fixed-size vector, you don’t need to worry about whether the co-occurrence space is too large.
feature_space <- layer_feature_space(
features = list(
# Categorical features encoded as integers
sex = "integer_categorical",
cp = "integer_categorical",
fbs = "integer_categorical",
restecg = "integer_categorical",
exang = "integer_categorical",
ca = "integer_categorical",
# Categorical feature encoded as string
thal = "string_categorical",
# Numerical features to discretize
age = "float_discretized",
# Numerical features to normalize
trestbps = "float_normalized",
chol = "float_normalized",
thalach = "float_normalized",
oldpeak = "float_normalized",
slope = "float_normalized"
),
# We create additional features by hashing
# value co-occurrences for the
# following groups of categorical features.
crosses = list(c("sex", "age"), c("thal", "ca")),
# The hashing space for these co-occurrences
# will be 32-dimensional.
crossing_dim = 32,
# Our utility will one-hot encode all categorical
# features and concat all features into a single
# vector (one vector per sample).
output_mode = "concat"
)
Further customizing a FeatureSpace
Specifying the feature type via a string name is quick and easy, but sometimes you may want to further configure the preprocessing of each feature. For instance, in our case, our categorical features don’t have a large set of possible values – it’s only a handful of values per feature (e.g. 1 and 0 for the feature "FBS"), and all possible values are represented in the training set. As a result, we don’t need to reserve an index to represent “out of vocabulary” values for these features – which would have been the default behavior. Below, we just specify num_oov_indices = 0 for each of these features to tell the feature preprocessor to skip “out of vocabulary” indexing.
Other customizations you have access to include specifying the number of bins for discretizing features of type "float_discretized", or the dimensionality of the hashing space for feature crossing.
feature_space <- layer_feature_space(
features = list(
# Categorical features encoded as integers
sex = feature_integer_categorical(num_oov_indices = 0),
cp = feature_integer_categorical(num_oov_indices = 0),
fbs = feature_integer_categorical(num_oov_indices = 0),
restecg = feature_integer_categorical(num_oov_indices = 0),
exang = feature_integer_categorical(num_oov_indices = 0),
ca = feature_integer_categorical(num_oov_indices = 0),
# Categorical feature encoded as string
thal = feature_string_categorical(num_oov_indices = 0),
# Numerical features to discretize
age = feature_float_discretized(num_bins = 30),
# Numerical features to normalize
trestbps = feature_float_normalized(),
chol = feature_float_normalized(),
thalach = feature_float_normalized(),
oldpeak = feature_float_normalized(),
slope = feature_float_normalized()
),
# Specify feature cross with a custom crossing dim.
crosses = list(
feature_cross(
feature_names = c("sex", "age"),
crossing_dim = 64
),
feature_cross(
feature_names = c("thal", "ca"),
crossing_dim = 16
)
),
output_mode = "concat"
)
Adapt the FeatureSpace to the training data
Before we start using the FeatureSpace to build a model, we have to adapt it to the training data. During adapt(), the FeatureSpace will:
- Index the set of possible values for categorical features.
- Compute the mean and variance for numerical features to normalize.
- Compute the value boundaries for the different bins for numerical features to discretize.
Note that adapt() should be called on a tf_dataset which yields dicts (named lists) of feature values – no labels.
train_ds_with_no_labels <- train_ds |> dataset_map(\(x, y) x)
feature_space |> adapt(train_ds_with_no_labels)
At this point, the FeatureSpace can be called on a dict of raw feature values, and it will return a single concatenated vector for each sample, combining encoded features and feature crosses.
c(x, y) %<-% iter_next(as_iterator(train_ds))
preprocessed_x <- feature_space(x)
preprocessed_x
## tf.Tensor(
## [[0. 0. 0. ... 0. 0. 0.]
## [0. 0. 0. ... 0. 0. 0.]
## [0. 0. 0. ... 0. 0. 0.]
## ...
## [0. 0. 0. ... 0. 0. 0.]
## [0. 0. 0. ... 0. 0. 0.]
## [0. 0. 0. ... 0. 0. 0.]], shape=(32, 136), dtype=float32)
Two ways to manage preprocessing: as part of the tf.data pipeline, or in the model itself
There are two ways in which you can leverage your FeatureSpace:
Asynchronous preprocessing in tf.data
You can make it part of your data pipeline, before the model. This enables asynchronous parallel preprocessing of the data on CPU before it hits the model. Do this if you’re training on GPU or TPU, or if you want to speed up preprocessing. This is almost always the right thing to do during training.
Synchronous preprocessing in the model
You can make it part of your model. This means that the model will expect dicts of raw feature values, and preprocessing will be done synchronously (in a blocking manner) before the rest of the forward pass. Do this if you want to have an end-to-end model that can process raw feature values – but keep in mind that your model will only be able to run on CPU, since most types of feature preprocessing (e.g. string preprocessing) are not GPU or TPU compatible.
Do not do this on GPU / TPU or in performance-sensitive settings. In general, you want to do in-model preprocessing when you do inference on CPU.
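To make the distinction concrete, here is a minimal sketch of both options, assuming the already-adapted feature_space from above (the names pipeline_ds and end_to_end_model are illustrative only; the full versions used in this example follow below):
# Option 1: apply the FeatureSpace in the tf.data pipeline (asynchronous)
pipeline_ds <- train_ds |>
  dataset_map(\(x, y) list(feature_space(x), y))
# Option 2: apply the FeatureSpace inside the model (synchronous)
inputs <- feature_space$get_inputs()
encoded <- feature_space$get_encoded_features()
outputs <- encoded |> layer_dense(1, activation = "sigmoid")
end_to_end_model <- keras_model(inputs, outputs)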
In our case, we will apply the FeatureSpace in the tf.data pipeline during training, but we will do inference with an end-to-end model that includes the FeatureSpace.
Let’s create a training and validation dataset of preprocessed batches:
preprocessed_train_ds <- train_ds |>
dataset_map(\(x, y) list(feature_space(x), y),
num_parallel_calls = tf$data$AUTOTUNE) |>
dataset_prefetch(tf$data$AUTOTUNE)
preprocessed_val_ds <- val_ds |>
dataset_map(\(x, y) list(feature_space(x), y),
num_parallel_calls = tf$data$AUTOTUNE) |>
dataset_prefetch(tf$data$AUTOTUNE)
Build a model
Time to build a model – or rather two models:
- A training model that expects preprocessed features (one sample = one vector)
- An inference model that expects raw features (one sample = dict of raw feature values)
dict_inputs <- feature_space$get_inputs()
encoded_features <- feature_space$get_encoded_features()
predictions <- encoded_features |>
layer_dense(32, activation = "relu") |>
layer_dropout(0.5) |>
layer_dense(1, activation = "sigmoid")
training_model <- keras_model(inputs = encoded_features,
outputs = predictions)
training_model |> compile(optimizer = "adam",
loss = "binary_crossentropy",
metrics = "accuracy")
inference_model <- keras_model(inputs = dict_inputs,
outputs = predictions)
Train the model
Let’s train our model for 20 epochs. Note that feature preprocessing is happening as part of the tf.data pipeline, not as part of the model.
training_model |> fit(
preprocessed_train_ds,
epochs = 20,
validation_data = preprocessed_val_ds,
verbose = 2
)
## Epoch 1/20
## 8/8 - 2s - 280ms/step - accuracy: 0.4689 - loss: 0.7471 - val_accuracy: 0.5167 - val_loss: 0.7019
## Epoch 2/20
## 8/8 - 1s - 140ms/step - accuracy: 0.5602 - loss: 0.6785 - val_accuracy: 0.6333 - val_loss: 0.6491
## Epoch 3/20
## 8/8 - 0s - 46ms/step - accuracy: 0.6307 - loss: 0.6478 - val_accuracy: 0.7000 - val_loss: 0.6053
## Epoch 4/20
## 8/8 - 0s - 12ms/step - accuracy: 0.6432 - loss: 0.6246 - val_accuracy: 0.7667 - val_loss: 0.5692
## Epoch 5/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7178 - loss: 0.5813 - val_accuracy: 0.7667 - val_loss: 0.5359
## Epoch 6/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7344 - loss: 0.5371 - val_accuracy: 0.7833 - val_loss: 0.5067
## Epoch 7/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7884 - loss: 0.5158 - val_accuracy: 0.8333 - val_loss: 0.4810
## Epoch 8/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7759 - loss: 0.5011 - val_accuracy: 0.8500 - val_loss: 0.4569
## Epoch 9/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7676 - loss: 0.4865 - val_accuracy: 0.8500 - val_loss: 0.4354
## Epoch 10/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7925 - loss: 0.4601 - val_accuracy: 0.8333 - val_loss: 0.4161
## Epoch 11/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7967 - loss: 0.4617 - val_accuracy: 0.8667 - val_loss: 0.3976
## Epoch 12/20
## 8/8 - 0s - 13ms/step - accuracy: 0.7967 - loss: 0.4316 - val_accuracy: 0.8667 - val_loss: 0.3796
## Epoch 13/20
## 8/8 - 0s - 13ms/step - accuracy: 0.8506 - loss: 0.4058 - val_accuracy: 0.8833 - val_loss: 0.3643
## Epoch 14/20
## 8/8 - 0s - 13ms/step - accuracy: 0.8174 - loss: 0.4197 - val_accuracy: 0.8833 - val_loss: 0.3510
## Epoch 15/20
## 8/8 - 0s - 14ms/step - accuracy: 0.8299 - loss: 0.3888 - val_accuracy: 0.8833 - val_loss: 0.3405
## Epoch 16/20
## 8/8 - 0s - 13ms/step - accuracy: 0.8257 - loss: 0.3820 - val_accuracy: 0.8833 - val_loss: 0.3294
## Epoch 17/20
## 8/8 - 0s - 13ms/step - accuracy: 0.8299 - loss: 0.3746 - val_accuracy: 0.8833 - val_loss: 0.3223
## Epoch 18/20
## 8/8 - 0s - 13ms/step - accuracy: 0.8506 - loss: 0.3487 - val_accuracy: 0.8833 - val_loss: 0.3153
## Epoch 19/20
## 8/8 - 0s - 14ms/step - accuracy: 0.8465 - loss: 0.3558 - val_accuracy: 0.8667 - val_loss: 0.3093
## Epoch 20/20
## 8/8 - 0s - 14ms/step - accuracy: 0.8672 - loss: 0.3570 - val_accuracy: 0.8667 - val_loss: 0.3036
We quickly get to over 85% validation accuracy.
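If you want a single summary figure rather than scanning the training log, you could also evaluate the trained model on the preprocessed validation set (a minimal sketch; this step is not part of the run above):
# Returns the validation loss and accuracy
training_model |> evaluate(preprocessed_val_ds, verbose = 0)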
Inference on new data with the end-to-end model
Now, we can use our inference model (which includes the FeatureSpace) to make predictions based on dicts of raw feature values, as follows:
sample <- list(
age = 60,
sex = 1,
cp = 1,
trestbps = 145,
chol = 233,
fbs = 1,
restecg = 2,
thalach = 150,
exang = 0,
oldpeak = 2.3,
slope = 3,
ca = 0,
thal = "fixed"
)
input_dict <- lapply(sample, \(x) op_convert_to_tensor(array(x)))
predictions <- inference_model |> predict(input_dict)
## 1/1 - 0s - 394ms/step
glue::glue(r"---(
This particular patient had a {(100 * predictions) |> signif(3)}% probability
of having a heart disease, as evaluated by our model.
)---")
## This particular patient had a 44.8% probability
## of having a heart disease, as evaluated by our model.