Prepare U-Net Xception-style model
get_model <- function(img_size, num_classes) {
inputs <- keras_input(shape = c(img_size, 3))
### [First half of the network: downsampling inputs] ###
# Entry block
x <- inputs |>
layer_conv_2d(filters = 32, kernel_size = 3, strides = 2, padding = "same") |>
layer_batch_normalization() |>
layer_activation("relu")
previous_block_activation <- x # Set aside residual
for (filters in c(64, 128, 256)) {
x <- x |>
layer_activation("relu") |>
layer_separable_conv_2d(filters = filters, kernel_size = 3, padding = "same") |>
layer_batch_normalization() |>
layer_activation("relu") |>
layer_separable_conv_2d(filters = filters, kernel_size = 3, padding = "same") |>
layer_batch_normalization() |>
layer_max_pooling_2d(pool_size = 3, strides = 2, padding = "same")
residual <- previous_block_activation |>
layer_conv_2d(filters = filters, kernel_size = 1, strides = 2, padding = "same")
x <- layer_add(x, residual) # Add back residual
previous_block_activation <- x # Set aside next residual
}
### [Second half of the network: upsampling inputs] ###
for (filters in c(256, 128, 64, 32)) {
x <- x |>
layer_activation("relu") |>
layer_conv_2d_transpose(filters = filters, kernel_size = 3, padding = "same") |>
layer_batch_normalization() |>
layer_activation("relu") |>
layer_conv_2d_transpose(filters = filters, kernel_size = 3, padding = "same") |>
layer_batch_normalization() |>
layer_upsampling_2d(size = 2)
# Project residual
residual <- previous_block_activation |>
layer_upsampling_2d(size = 2) |>
layer_conv_2d(filters = filters, kernel_size = 1, padding = "same")
x <- layer_add(x, residual) # Add back residual
previous_block_activation <- x # Set aside next residual
}
# Add a per-pixel classification layer
outputs <- x |>
layer_conv_2d(num_classes, 3, activation = "softmax", padding = "same")
# Define the model
keras_model(inputs, outputs)
}
# Build model
model <- get_model(img_size, num_classes)
summary(model)
## Model: "functional"
## ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━┓
## ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ Trai… ┃
## ┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━┩
## │ input_layer │ (None, 160, │ 0 │ - │ - │
## │ (InputLayer) │ 160, 3) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d (Conv2D) │ (None, 80, 80, │ 896 │ input_layer[0… │ Y │
## │ │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80, │ 128 │ conv2d[0][0] │ Y │
## │ (BatchNormalizat… │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation │ (None, 80, 80, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_1 │ (None, 80, 80, │ 0 │ activation[0]… │ - │
## │ (Activation) │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d │ (None, 80, 80, │ 2,400 │ activation_1[… │ Y │
## │ (SeparableConv2D) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80, │ 256 │ separable_con… │ Y │
## │ (BatchNormalizat… │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_2 │ (None, 80, 80, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 80, 80, │ 4,736 │ activation_2[… │ Y │
## │ (SeparableConv2D) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80, │ 256 │ separable_con… │ Y │
## │ (BatchNormalizat… │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ max_pooling2d │ (None, 40, 40, │ 0 │ batch_normali… │ - │
## │ (MaxPooling2D) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_1 (Conv2D) │ (None, 40, 40, │ 2,112 │ activation[0]… │ Y │
## │ │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add (Add) │ (None, 40, 40, │ 0 │ max_pooling2d… │ - │
## │ │ 64) │ │ conv2d_1[0][0] │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_3 │ (None, 40, 40, │ 0 │ add[0][0] │ - │
## │ (Activation) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 40, 40, │ 8,896 │ activation_3[… │ Y │
## │ (SeparableConv2D) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40, │ 512 │ separable_con… │ Y │
## │ (BatchNormalizat… │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_4 │ (None, 40, 40, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 40, 40, │ 17,664 │ activation_4[… │ Y │
## │ (SeparableConv2D) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40, │ 512 │ separable_con… │ Y │
## │ (BatchNormalizat… │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ max_pooling2d_1 │ (None, 20, 20, │ 0 │ batch_normali… │ - │
## │ (MaxPooling2D) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_2 (Conv2D) │ (None, 20, 20, │ 8,320 │ add[0][0] │ Y │
## │ │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_1 (Add) │ (None, 20, 20, │ 0 │ max_pooling2d… │ - │
## │ │ 128) │ │ conv2d_2[0][0] │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_5 │ (None, 20, 20, │ 0 │ add_1[0][0] │ - │
## │ (Activation) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 20, 20, │ 34,176 │ activation_5[… │ Y │
## │ (SeparableConv2D) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20, │ 1,024 │ separable_con… │ Y │
## │ (BatchNormalizat… │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_6 │ (None, 20, 20, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ separable_conv2d… │ (None, 20, 20, │ 68,096 │ activation_6[… │ Y │
## │ (SeparableConv2D) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20, │ 1,024 │ separable_con… │ Y │
## │ (BatchNormalizat… │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ max_pooling2d_2 │ (None, 10, 10, │ 0 │ batch_normali… │ - │
## │ (MaxPooling2D) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_3 (Conv2D) │ (None, 10, 10, │ 33,024 │ add_1[0][0] │ Y │
## │ │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_2 (Add) │ (None, 10, 10, │ 0 │ max_pooling2d… │ - │
## │ │ 256) │ │ conv2d_3[0][0] │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_7 │ (None, 10, 10, │ 0 │ add_2[0][0] │ - │
## │ (Activation) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose │ (None, 10, 10, │ 590,080 │ activation_7[… │ Y │
## │ (Conv2DTranspose) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 10, 10, │ 1,024 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_8 │ (None, 10, 10, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 10, 10, │ 590,080 │ activation_8[… │ Y │
## │ (Conv2DTranspose) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 10, 10, │ 1,024 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_1 │ (None, 20, 20, │ 0 │ add_2[0][0] │ - │
## │ (UpSampling2D) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d │ (None, 20, 20, │ 0 │ batch_normali… │ - │
## │ (UpSampling2D) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_4 (Conv2D) │ (None, 20, 20, │ 65,792 │ up_sampling2d… │ Y │
## │ │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_3 (Add) │ (None, 20, 20, │ 0 │ up_sampling2d… │ - │
## │ │ 256) │ │ conv2d_4[0][0] │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_9 │ (None, 20, 20, │ 0 │ add_3[0][0] │ - │
## │ (Activation) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 20, 20, │ 295,040 │ activation_9[… │ Y │
## │ (Conv2DTranspose) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20, │ 512 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_10 │ (None, 20, 20, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 20, 20, │ 147,584 │ activation_10… │ Y │
## │ (Conv2DTranspose) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 20, 20, │ 512 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_3 │ (None, 40, 40, │ 0 │ add_3[0][0] │ - │
## │ (UpSampling2D) │ 256) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_2 │ (None, 40, 40, │ 0 │ batch_normali… │ - │
## │ (UpSampling2D) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_5 (Conv2D) │ (None, 40, 40, │ 32,896 │ up_sampling2d… │ Y │
## │ │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_4 (Add) │ (None, 40, 40, │ 0 │ up_sampling2d… │ - │
## │ │ 128) │ │ conv2d_5[0][0] │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_11 │ (None, 40, 40, │ 0 │ add_4[0][0] │ - │
## │ (Activation) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 40, 40, │ 73,792 │ activation_11… │ Y │
## │ (Conv2DTranspose) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40, │ 256 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_12 │ (None, 40, 40, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 40, 40, │ 36,928 │ activation_12… │ Y │
## │ (Conv2DTranspose) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 40, 40, │ 256 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_5 │ (None, 80, 80, │ 0 │ add_4[0][0] │ - │
## │ (UpSampling2D) │ 128) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_4 │ (None, 80, 80, │ 0 │ batch_normali… │ - │
## │ (UpSampling2D) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_6 (Conv2D) │ (None, 80, 80, │ 8,256 │ up_sampling2d… │ Y │
## │ │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_5 (Add) │ (None, 80, 80, │ 0 │ up_sampling2d… │ - │
## │ │ 64) │ │ conv2d_6[0][0] │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_13 │ (None, 80, 80, │ 0 │ add_5[0][0] │ - │
## │ (Activation) │ 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 80, 80, │ 18,464 │ activation_13… │ Y │
## │ (Conv2DTranspose) │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80, │ 128 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ activation_14 │ (None, 80, 80, │ 0 │ batch_normali… │ - │
## │ (Activation) │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_transpose… │ (None, 80, 80, │ 9,248 │ activation_14… │ Y │
## │ (Conv2DTranspose) │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ batch_normalizat… │ (None, 80, 80, │ 128 │ conv2d_transp… │ Y │
## │ (BatchNormalizat… │ 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_7 │ (None, 160, │ 0 │ add_5[0][0] │ - │
## │ (UpSampling2D) │ 160, 64) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ up_sampling2d_6 │ (None, 160, │ 0 │ batch_normali… │ - │
## │ (UpSampling2D) │ 160, 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_7 (Conv2D) │ (None, 160, │ 2,080 │ up_sampling2d… │ Y │
## │ │ 160, 32) │ │ │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ add_6 (Add) │ (None, 160, │ 0 │ up_sampling2d… │ - │
## │ │ 160, 32) │ │ conv2d_7[0][0] │ │
## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤
## │ conv2d_8 (Conv2D) │ (None, 160, │ 867 │ add_6[0][0] │ Y │
## │ │ 160, 3) │ │ │ │
## └───────────────────┴─────────────────┴───────────┴────────────────┴───────┘
## Total params: 2,058,979 (7.85 MB)
## Trainable params: 2,055,203 (7.84 MB)
## Non-trainable params: 3,776 (14.75 KB)