Skip to contents

Download the data

# The archives are large, so raise R's download timeout (default 60 s) for
# this session.
options(timeout = 5000)

# download.file() does not create missing directories; make sure the
# destination exists before writing into it.
dir.create("datasets", showWarnings = FALSE, recursive = TRUE)

# mode = "wb" forces a binary transfer so the .tar.gz archives are never
# written in text mode (which corrupts binary files on Windows).
download.file(
  "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
  "datasets/images.tar.gz",
  mode = "wb"
)
download.file(
  "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
  "datasets/annotations.tar.gz",
  mode = "wb"
)

# Extract images/ and annotations/ into datasets/.
untar("datasets/images.tar.gz", exdir = "datasets")
untar("datasets/annotations.tar.gz", exdir = "datasets")

Prepare paths of input images and target segmentation masks

library(keras3)
input_dir <- "datasets/images/"
target_dir <- "datasets/annotations/trimaps/"
img_size <- c(160, 160)
num_classes <- 3
batch_size <- 32

input_img_paths <- fs::dir_ls(input_dir, glob = "*.jpg") |> sort()
target_img_paths <- fs::dir_ls(target_dir, glob = "*.png") |> sort()

# Images and masks are paired purely by position after sorting, so check
# that every image has a mask with the same file stem (Abyssinian_1.jpg <->
# Abyssinian_1.png) before relying on that pairing downstream.
input_stems <- as.character(fs::path_ext_remove(fs::path_file(input_img_paths)))
target_stems <- as.character(fs::path_ext_remove(fs::path_file(target_img_paths)))
if (!identical(input_stems, target_stems)) {
  stop("Input images and target masks do not pair up by file name.",
       call. = FALSE)
}

cat("Number of samples:", length(input_img_paths), "\n")
# Show the first few pairs; seq_len() keeps this safe for short listings.
for (i in seq_len(min(10, length(input_img_paths)))) {
  cat(input_img_paths[i], "|", target_img_paths[i], "\n")
}
## datasets/images/Abyssinian_1.jpg | datasets/annotations/trimaps/Abyssinian_1.png
## datasets/images/Abyssinian_10.jpg | datasets/annotations/trimaps/Abyssinian_10.png
## datasets/images/Abyssinian_100.jpg | datasets/annotations/trimaps/Abyssinian_100.png
## datasets/images/Abyssinian_101.jpg | datasets/annotations/trimaps/Abyssinian_101.png
## datasets/images/Abyssinian_102.jpg | datasets/annotations/trimaps/Abyssinian_102.png
## datasets/images/Abyssinian_103.jpg | datasets/annotations/trimaps/Abyssinian_103.png
## datasets/images/Abyssinian_104.jpg | datasets/annotations/trimaps/Abyssinian_104.png
## datasets/images/Abyssinian_105.jpg | datasets/annotations/trimaps/Abyssinian_105.png
## datasets/images/Abyssinian_106.jpg | datasets/annotations/trimaps/Abyssinian_106.png
## datasets/images/Abyssinian_107.jpg | datasets/annotations/trimaps/Abyssinian_107.png

What does one input image and corresponding segmentation mask look like?

# Display input image #10: decode the JPEG into an RGB array, convert it to a
# raster and draw it.
plot(as.raster(jpeg::readJPEG(input_img_paths[10])))
plot of chunk unnamed-chunk-4
plot of chunk unnamed-chunk-4
# Display the trimap for image #10. readPNG() returns values scaled to
# [0, 1]; multiplying by 255 recovers the integer labels 1-3, which are then
# drawn as grey levels with max = 3.
plot(as.raster(png::readPNG(target_img_paths[10]) * 255, max = 3))
plot of chunk unnamed-chunk-4
plot of chunk unnamed-chunk-4

Prepare dataset to load & vectorize batches of data

library(tensorflow, exclude = c("shape", "set_random_seed"))
library(tfdatasets, exclude = "shape")


# Returns a tf_dataset
get_dataset <- function(batch_size, img_size, input_img_paths, target_img_paths,
                        max_dataset_len = NULL) {
  # Build a batched tf_dataset of (image, mask) pairs read from disk.
  #
  # Args:
  #   batch_size: number of samples per batch.
  #   img_size: length-2 vector c(height, width); both images and masks are
  #     resized to it.
  #   input_img_paths: JPEG image paths.
  #   target_img_paths: PNG trimap paths, parallel to `input_img_paths`
  #     (element k of one belongs to element k of the other).
  #   max_dataset_len: optional cap; only the first `max_dataset_len` pairs
  #     are used. NULL (the default) uses every pair.
  #
  # Returns a tf_dataset whose elements are list(images, masks): float32
  # images with 3 channels and uint8 masks with 1 channel holding labels
  # 0, 1, 2.
  if (length(input_img_paths) != length(target_img_paths)) {
    stop("`input_img_paths` and `target_img_paths` must have the same length.",
         call. = FALSE)
  }

  img_size <- as.integer(img_size)

  load_img_masks <- function(input_img_path, target_img_path) {
    # NOTE: per the TensorFlow docs, tf$image$resize() with its default
    # (bilinear) method already returns float32 without rescaling, so the
    # convert_image_dtype("float32") step does not rescale and pixel values
    # stay in [0, 255]. Training and prediction both use this function, so
    # the model always sees the same range.
    input_img <- input_img_path |>
      tf$io$read_file() |>
      tf$io$decode_jpeg(channels = 3) |>
      tf$image$resize(img_size) |>
      tf$image$convert_image_dtype("float32")

    # Nearest-neighbour resizing keeps mask values as exact class labels
    # (no interpolation between neighbouring labels).
    target_img <- target_img_path |>
      tf$io$read_file() |>
      tf$io$decode_png(channels = 1) |>
      tf$image$resize(img_size, method = "nearest") |>
      tf$image$convert_image_dtype("uint8")

    # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
    target_img <- target_img - 1L

    list(input_img, target_img)
  }

  if (!is.null(max_dataset_len)) {
    # head() never indexes past the end. `x[1:n]` would pad with NA paths
    # when n > length(x), and would keep one element when n == 0.
    input_img_paths <- head(input_img_paths, max_dataset_len)
    target_img_paths <- head(target_img_paths, max_dataset_len)
  }

  list(input_img_paths, target_img_paths) |>
    tensor_slices_dataset() |>
    dataset_map(load_img_masks, num_parallel_calls = tf$data$AUTOTUNE) |>
    dataset_batch(batch_size)
}

Prepare U-Net Xception-style model

get_model <- function(img_size, num_classes) {
  # Build an Xception-style U-Net for per-pixel classification.
  #
  # Args:
  #   img_size: length-2 vector c(height, width) of the RGB input images.
  #   num_classes: number of classes; the last layer applies a softmax over
  #     this many channels at every pixel.
  #
  # Returns an uncompiled keras functional model. The encoder halves the
  # spatial size four times and the decoder doubles it four times, so for
  # sizes divisible by 16 (e.g. 160x160) the output keeps the input's
  # height and width.

  # Encoder stage: two (ReLU -> separable conv -> batch norm) steps, then a
  # stride-2 max pool that halves the spatial size.
  downsample_block <- function(tensor, n_filters) {
    tensor |>
      layer_activation("relu") |>
      layer_separable_conv_2d(
        filters = n_filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_activation("relu") |>
      layer_separable_conv_2d(
        filters = n_filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_max_pooling_2d(pool_size = 3, strides = 2, padding = "same")
  }

  # Decoder stage: two (ReLU -> transposed conv -> batch norm) steps, then a
  # 2x upsampling that doubles the spatial size.
  upsample_block <- function(tensor, n_filters) {
    tensor |>
      layer_activation("relu") |>
      layer_conv_2d_transpose(
        filters = n_filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_activation("relu") |>
      layer_conv_2d_transpose(
        filters = n_filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_upsampling_2d(size = 2)
  }

  inputs <- keras_input(shape = c(img_size, 3))

  # Entry block: stride-2 convolution halves the spatial size.
  x <- inputs |>
    layer_conv_2d(filters = 32, kernel_size = 3, strides = 2, padding = "same") |>
    layer_batch_normalization() |>
    layer_activation("relu")
  skip <- x

  # Downsampling half. The shortcut is a strided 1x1 projection so it
  # matches the block output's shape before the two are summed.
  for (n_filters in c(64, 128, 256)) {
    main_path <- downsample_block(x, n_filters)
    shortcut <- layer_conv_2d(
      skip, filters = n_filters, kernel_size = 1, strides = 2, padding = "same"
    )
    x <- layer_add(main_path, shortcut)
    skip <- x
  }

  # Upsampling half. The shortcut is upsampled and then projected with a 1x1
  # convolution so it matches the block output's shape.
  for (n_filters in c(256, 128, 64, 32)) {
    main_path <- upsample_block(x, n_filters)
    shortcut <- skip |>
      layer_upsampling_2d(size = 2) |>
      layer_conv_2d(filters = n_filters, kernel_size = 1, padding = "same")
    x <- layer_add(main_path, shortcut)
    skip <- x
  }

  # Per-pixel classification head.
  outputs <- layer_conv_2d(
    x, filters = num_classes, kernel_size = 3,
    activation = "softmax", padding = "same"
  )

  keras_model(inputs, outputs)
}

# Instantiate the U-Net for our image size and class count, then print its
# layer-by-layer summary.
model <- get_model(img_size = img_size, num_classes = num_classes)
summary(model)
## Model: "functional"
## ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━┓
## ┃ Layer (type)       Output Shape       Param #  Connected to    Trai… 
## ┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━┩
## │ input_layer       │ (None, 160,     │         0 │ -              │   -
## │ (InputLayer)      │ 160, 3)         │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d (Conv2D)   │ (None, 80, 80,  │       896 │ input_layer[0… │   Y
## │                   │ 32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80,  │       128 │ conv2d[0][0]   │   Y
## │ (BatchNormalizat…32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation        │ (None, 80, 80,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_1      │ (None, 80, 80,  │         0 │ activation[0]… │   -
## │ (Activation)      │ 32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d  │ (None, 80, 80,  │     2,400 │ activation_1[Y
## │ (SeparableConv2D) │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80,  │       256 │ separable_con… │   Y
## │ (BatchNormalizat…64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_2      │ (None, 80, 80,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 80, 80,  │     4,736 │ activation_2[Y
## │ (SeparableConv2D) │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80,  │       256 │ separable_con… │   Y
## │ (BatchNormalizat…64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ max_pooling2d     │ (None, 40, 40,  │         0 │ batch_normali… │   -
## │ (MaxPooling2D)    │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_1 (Conv2D) │ (None, 40, 40,  │     2,112 │ activation[0]… │   Y
## │                   │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add (Add)         │ (None, 40, 40,  │         0 │ max_pooling2d… │   -
## │                   │ 64)             │           │ conv2d_1[0][0] │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_3      │ (None, 40, 40,  │         0 │ add[0][0]      │   -
## │ (Activation)      │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 40, 40,  │     8,896 │ activation_3[Y
## │ (SeparableConv2D) │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40,  │       512 │ separable_con… │   Y
## │ (BatchNormalizat…128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_4      │ (None, 40, 40,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 40, 40,  │    17,664 │ activation_4[Y
## │ (SeparableConv2D) │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40,  │       512 │ separable_con… │   Y
## │ (BatchNormalizat…128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ max_pooling2d_1   │ (None, 20, 20,  │         0 │ batch_normali… │   -
## │ (MaxPooling2D)    │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_2 (Conv2D) │ (None, 20, 20,  │     8,320 │ add[0][0]      │   Y
## │                   │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_1 (Add)       │ (None, 20, 20,  │         0 │ max_pooling2d… │   -
## │                   │ 128)            │           │ conv2d_2[0][0] │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_5      │ (None, 20, 20,  │         0 │ add_1[0][0]    │   -
## │ (Activation)      │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 20, 20,  │    34,176 │ activation_5[Y
## │ (SeparableConv2D) │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20,  │     1,024 │ separable_con… │   Y
## │ (BatchNormalizat…256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_6      │ (None, 20, 20,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 20, 20,  │    68,096 │ activation_6[Y
## │ (SeparableConv2D) │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20,  │     1,024 │ separable_con… │   Y
## │ (BatchNormalizat…256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ max_pooling2d_2   │ (None, 10, 10,  │         0 │ batch_normali… │   -
## │ (MaxPooling2D)    │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_3 (Conv2D) │ (None, 10, 10,  │    33,024 │ add_1[0][0]    │   Y
## │                   │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_2 (Add)       │ (None, 10, 10,  │         0 │ max_pooling2d… │   -
## │                   │ 256)            │           │ conv2d_3[0][0] │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_7      │ (None, 10, 10,  │         0 │ add_2[0][0]    │   -
## │ (Activation)      │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose  │ (None, 10, 10,  │   590,080 │ activation_7[Y
## │ (Conv2DTranspose) │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 10, 10,  │     1,024 │ conv2d_transp… │   Y
## │ (BatchNormalizat…256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_8      │ (None, 10, 10,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 10, 10,  │   590,080 │ activation_8[Y
## │ (Conv2DTranspose) │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 10, 10,  │     1,024 │ conv2d_transp… │   Y
## │ (BatchNormalizat…256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_1   │ (None, 20, 20,  │         0 │ add_2[0][0]    │   -
## │ (UpSampling2D)    │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d     │ (None, 20, 20,  │         0 │ batch_normali… │   -
## │ (UpSampling2D)    │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_4 (Conv2D) │ (None, 20, 20,  │    65,792 │ up_sampling2d… │   Y
## │                   │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_3 (Add)       │ (None, 20, 20,  │         0 │ up_sampling2d… │   -
## │                   │ 256)            │           │ conv2d_4[0][0] │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_9      │ (None, 20, 20,  │         0 │ add_3[0][0]    │   -
## │ (Activation)      │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 20, 20,  │   295,040 │ activation_9[Y
## │ (Conv2DTranspose) │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20,  │       512 │ conv2d_transp… │   Y
## │ (BatchNormalizat…128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_10     │ (None, 20, 20,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 20, 20,  │   147,584 │ activation_10… │   Y
## │ (Conv2DTranspose) │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20,  │       512 │ conv2d_transp… │   Y
## │ (BatchNormalizat…128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_3   │ (None, 40, 40,  │         0 │ add_3[0][0]    │   -
## │ (UpSampling2D)    │ 256)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_2   │ (None, 40, 40,  │         0 │ batch_normali… │   -
## │ (UpSampling2D)    │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_5 (Conv2D) │ (None, 40, 40,  │    32,896 │ up_sampling2d… │   Y
## │                   │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_4 (Add)       │ (None, 40, 40,  │         0 │ up_sampling2d… │   -
## │                   │ 128)            │           │ conv2d_5[0][0] │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_11     │ (None, 40, 40,  │         0 │ add_4[0][0]    │   -
## │ (Activation)      │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 40, 40,  │    73,792 │ activation_11… │   Y
## │ (Conv2DTranspose) │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40,  │       256 │ conv2d_transp… │   Y
## │ (BatchNormalizat…64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_12     │ (None, 40, 40,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 40, 40,  │    36,928 │ activation_12… │   Y
## │ (Conv2DTranspose) │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40,  │       256 │ conv2d_transp… │   Y
## │ (BatchNormalizat…64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_5   │ (None, 80, 80,  │         0 │ add_4[0][0]    │   -
## │ (UpSampling2D)    │ 128)            │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_4   │ (None, 80, 80,  │         0 │ batch_normali… │   -
## │ (UpSampling2D)    │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_6 (Conv2D) │ (None, 80, 80,  │     8,256 │ up_sampling2d… │   Y
## │                   │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_5 (Add)       │ (None, 80, 80,  │         0 │ up_sampling2d… │   -
## │                   │ 64)             │           │ conv2d_6[0][0] │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_13     │ (None, 80, 80,  │         0 │ add_5[0][0]    │   -
## │ (Activation)      │ 64)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 80, 80,  │    18,464 │ activation_13… │   Y
## │ (Conv2DTranspose) │ 32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80,  │       128 │ conv2d_transp… │   Y
## │ (BatchNormalizat…32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_14     │ (None, 80, 80,  │         0 │ batch_normali… │   -
## │ (Activation)      │ 32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 80, 80,  │     9,248 │ activation_14… │   Y
## │ (Conv2DTranspose) │ 32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80,  │       128 │ conv2d_transp… │   Y
## │ (BatchNormalizat…32)             │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_7   │ (None, 160,     │         0 │ add_5[0][0]    │   -
## │ (UpSampling2D)    │ 160, 64)        │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_6   │ (None, 160,     │         0 │ batch_normali… │   -
## │ (UpSampling2D)    │ 160, 32)        │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_7 (Conv2D) │ (None, 160,     │     2,080 │ up_sampling2d… │   Y
## │                   │ 160, 32)        │           │                │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_6 (Add)       │ (None, 160,     │         0 │ up_sampling2d… │   -
## │                   │ 160, 32)        │           │ conv2d_7[0][0] │       │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_8 (Conv2D) │ (None, 160,     │       867 │ add_6[0][0]    │   Y
## │                   │ 160, 3)         │           │                │       │
## └───────────────────┴─────────────────┴───────────┴────────────────┴───────┘
##  Total params: 2,058,979 (7.85 MB)
##  Trainable params: 2,055,203 (7.84 MB)
##  Non-trainable params: 3,776 (14.75 KB)

Set aside a validation split

# Split our img paths into a training and a validation set.
#
# The paths are sorted by file name, so images sharing a breed prefix are
# adjacent (e.g. all Abyssinian_* files come first). Shuffle the indices once
# and split the permutation. Otherwise the `max_dataset_len` cap below would
# train only on the first few breeds. set.seed() makes the split
# reproducible; note that it resets R's global RNG state.
val_samples <- 1000
stopifnot(val_samples < length(input_img_paths))

set.seed(1337)
shuffled_idx <- sample.int(length(input_img_paths))
val_idx <- shuffled_idx[seq_len(val_samples)]
train_idx <- shuffled_idx[-seq_len(val_samples)]

train_input_img_paths <- input_img_paths[train_idx]
train_target_img_paths <- target_img_paths[train_idx]

val_input_img_paths <- input_img_paths[val_idx]
val_target_img_paths <- target_img_paths[val_idx]

# Instantiate dataset for each split
# Limit input files in `max_dataset_len` for faster epoch training time.
# Because the training paths are shuffled, this is a random subset.
# Remove the `max_dataset_len` arg when running with full dataset.
train_dataset <- get_dataset(
  batch_size,
  img_size,
  train_input_img_paths,
  train_target_img_paths,
  max_dataset_len = 1000
)
valid_dataset <- get_dataset(
  batch_size, img_size, val_input_img_paths, val_target_img_paths
)

Train the model

# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.
model |> compile(
  optimizer = optimizer_adam(1e-4),
  loss = "sparse_categorical_crossentropy"
)

# Make sure the checkpoint directory exists before training starts, so the
# first save cannot fail on a missing folder.
dir.create("models", showWarnings = FALSE, recursive = TRUE)

# Keep only the weights with the lowest validation loss seen so far.
callbacks <- list(
  callback_model_checkpoint(
    "models/oxford_segmentation.keras",
    monitor = "val_loss",
    save_best_only = TRUE
  )
)

# Train the model, doing validation at the end of each epoch.
epochs <- 50
model |> fit(
  train_dataset,
  epochs = epochs,
  validation_data = valid_dataset,
  callbacks = callbacks,
  verbose = 2
)
## Epoch 1/50
## 32/32 - 35s - 1s/step - loss: 1.4284 - val_loss: 1.5501
## Epoch 2/50
## 32/32 - 2s - 66ms/step - loss: 0.9222 - val_loss: 1.9888
## Epoch 3/50
## 32/32 - 2s - 69ms/step - loss: 0.7765 - val_loss: 2.5120
## Epoch 4/50
## 32/32 - 2s - 65ms/step - loss: 0.7201 - val_loss: 3.0113
## Epoch 5/50
## 32/32 - 2s - 66ms/step - loss: 0.6848 - val_loss: 3.2869
## Epoch 6/50
## 32/32 - 2s - 63ms/step - loss: 0.6557 - val_loss: 3.4469
## Epoch 7/50
## 32/32 - 2s - 67ms/step - loss: 0.6304 - val_loss: 3.5521
## Epoch 8/50
## 32/32 - 2s - 66ms/step - loss: 0.6084 - val_loss: 3.6420
## Epoch 9/50
## 32/32 - 2s - 63ms/step - loss: 0.5895 - val_loss: 3.7225
## Epoch 10/50
## 32/32 - 2s - 64ms/step - loss: 0.5727 - val_loss: 3.7848
## Epoch 11/50
## 32/32 - 2s - 63ms/step - loss: 0.5568 - val_loss: 3.8129
## Epoch 12/50
## 32/32 - 2s - 65ms/step - loss: 0.5409 - val_loss: 3.7908
## Epoch 13/50
## 32/32 - 2s - 68ms/step - loss: 0.5243 - val_loss: 3.7079
## Epoch 14/50
## 32/32 - 2s - 70ms/step - loss: 0.5063 - val_loss: 3.5739
## Epoch 15/50
## 32/32 - 2s - 67ms/step - loss: 0.4861 - val_loss: 3.4241
## Epoch 16/50
## 32/32 - 2s - 63ms/step - loss: 0.4638 - val_loss: 3.2322
## Epoch 17/50
## 32/32 - 2s - 64ms/step - loss: 0.4396 - val_loss: 2.9929
## Epoch 18/50
## 32/32 - 2s - 63ms/step - loss: 0.4140 - val_loss: 2.6925
## Epoch 19/50
## 32/32 - 2s - 64ms/step - loss: 0.3875 - val_loss: 2.3410
## Epoch 20/50
## 32/32 - 2s - 72ms/step - loss: 0.3619 - val_loss: 1.9768
## Epoch 21/50
## 32/32 - 2s - 68ms/step - loss: 0.3389 - val_loss: 1.6291
## Epoch 22/50
## 32/32 - 2s - 73ms/step - loss: 0.3201 - val_loss: 1.3334
## Epoch 23/50
## 32/32 - 2s - 70ms/step - loss: 0.3078 - val_loss: 1.0987
## Epoch 24/50
## 32/32 - 2s - 70ms/step - loss: 0.3073 - val_loss: 1.0320
## Epoch 25/50
## 32/32 - 2s - 68ms/step - loss: 0.3392 - val_loss: 0.9181
## Epoch 26/50
## 32/32 - 2s - 65ms/step - loss: 0.3674 - val_loss: 0.9856
## Epoch 27/50
## 32/32 - 2s - 72ms/step - loss: 0.3336 - val_loss: 0.8418
## Epoch 28/50
## 32/32 - 2s - 63ms/step - loss: 0.2901 - val_loss: 0.9893
## Epoch 29/50
## 32/32 - 2s - 65ms/step - loss: 0.2737 - val_loss: 1.1443
## Epoch 30/50
## 32/32 - 2s - 66ms/step - loss: 0.2681 - val_loss: 1.1886
## Epoch 31/50
## 32/32 - 2s - 67ms/step - loss: 0.2707 - val_loss: 1.1304
## Epoch 32/50
## 32/32 - 2s - 65ms/step - loss: 0.2860 - val_loss: 1.0745
## Epoch 33/50
## 32/32 - 2s - 65ms/step - loss: 0.3135 - val_loss: 1.2410
## Epoch 34/50
## 32/32 - 2s - 67ms/step - loss: 0.3016 - val_loss: 1.2617
## Epoch 35/50
## 32/32 - 2s - 65ms/step - loss: 0.2858 - val_loss: 1.0743
## Epoch 36/50
## 32/32 - 2s - 66ms/step - loss: 0.2803 - val_loss: 1.0980
## Epoch 37/50
## 32/32 - 2s - 69ms/step - loss: 0.2762 - val_loss: 1.4966
## Epoch 38/50
## 32/32 - 2s - 65ms/step - loss: 0.2625 - val_loss: 1.0397
## Epoch 39/50
## 32/32 - 2s - 65ms/step - loss: 0.2475 - val_loss: 1.1995
## Epoch 40/50
## 32/32 - 2s - 67ms/step - loss: 0.2468 - val_loss: 1.1267
## Epoch 41/50
## 32/32 - 2s - 65ms/step - loss: 0.2416 - val_loss: 1.3864
## Epoch 42/50
## 32/32 - 2s - 64ms/step - loss: 0.2304 - val_loss: 1.4553
## Epoch 43/50
## 32/32 - 2s - 65ms/step - loss: 0.2287 - val_loss: 1.2477
## Epoch 44/50
## 32/32 - 2s - 65ms/step - loss: 0.2223 - val_loss: 1.1645
## Epoch 45/50
## 32/32 - 2s - 67ms/step - loss: 0.2176 - val_loss: 1.0877
## Epoch 46/50
## 32/32 - 2s - 65ms/step - loss: 0.2109 - val_loss: 1.0843
## Epoch 47/50
## 32/32 - 2s - 70ms/step - loss: 0.2064 - val_loss: 1.0662
## Epoch 48/50
## 32/32 - 2s - 66ms/step - loss: 0.2004 - val_loss: 1.0927
## Epoch 49/50
## 32/32 - 2s - 64ms/step - loss: 0.1934 - val_loss: 1.0553
## Epoch 50/50
## 32/32 - 2s - 64ms/step - loss: 0.1839 - val_loss: 1.1265

Visualize predictions

# Reload the best checkpoint saved during training.
model <- load_model("models/oxford_segmentation.keras")
# Generate predictions for all images in the validation set
val_dataset <- get_dataset(
  batch_size, img_size, val_input_img_paths, val_target_img_paths
)
val_preds <- predict(model, val_dataset)
## 32/32 - 3s - 94ms/step

# Plot the model's predicted segmentation mask for validation sample `i`.
#
# Args:
#   i: index of the sample along the first axis of `preds`.
#   preds: prediction array (sample, height, width, class scores); defaults
#     to the global `val_preds`.
display_mask <- function(i, preds = val_preds) {
  # Pick the highest-scoring class per pixel. which.max() yields 1-based
  # labels 1..3, matching the raw trimap labels, so max = 3 gives the same
  # grey levels used for the ground-truth plot.
  mask <- apply(preds[i, , , ], c(1, 2), which.max)
  # as.raster() on a matrix draws equal R, G and B channels, the same result
  # as stacking the mask into three identical channels.
  plot(as.raster(mask, max = 3))
}

# Display results for validation image #10: input, ground truth, prediction.
i <- 10

# Three panels side by side; the previous layout is restored at the end.
old_par <- par(mfrow = c(1, 3))

# Display input image. Use the *validation* paths: display_mask(i) plots
# val_preds[i, , , ], which was predicted from val_input_img_paths[i], not
# from input_img_paths[i].
val_input_img_paths[i] |>
  jpeg::readJPEG() |>
  as.raster() |>
  plot()

# Display ground-truth target mask (PNG values * 255 -> labels 1-3).
val_target_img_paths[i] |>
  png::readPNG() |>
  magrittr::multiply_by(255) |>
  as.raster(max = 3) |>
  plot()

# Display mask predicted by our model
display_mask(i)  # Note that the model only sees inputs resized to 160x160.

par(old_par)
plot of chunk unnamed-chunk-9
plot of chunk unnamed-chunk-9