Skip to contents

Freeze weights in a model or layer so that they are no longer trainable.

Usage

freeze_weights(object, from = NULL, to = NULL, which = NULL)

unfreeze_weights(object, from = NULL, to = NULL, which = NULL)

Arguments

object

Keras model or layer object

from

Layer instance, layer name, or layer index within model

to

Layer instance, layer name, or layer index within model

which

Layer names, integer positions, layers, a logical vector (of length(object$layers)), or a function returning a logical vector.

Value

The input object, with its weights frozen or unfrozen, is returned invisibly. Note that object is modified in place; the return value is provided only to make usage with the pipe convenient.

Note

The from and to layer arguments are both inclusive.

When applied to a model, the freeze or unfreeze is a global operation over all layers in the model (i.e. layers not within the specified range will be set to the opposite value, e.g. unfrozen for a call to freeze).

Models must be compiled again after weights are frozen or unfrozen.

Examples

# instantiate a VGG16 model
conv_base <- application_vgg16(
  weights = "imagenet",
  include_top = FALSE,
  input_shape = c(150, 150, 3)
)

# freeze its weights
freeze_weights(conv_base)

# Note the "Trainable" column
conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 0 (0.00 B)
##  Non-trainable params: 14,714,688 (56.13 MB)


# create a composite model that includes the base + more layers
model <- keras_model_sequential(input_batch_shape = shape(conv_base$input)) |>
  conv_base() |>
  layer_flatten() |>
  layer_dense(units = 256, activation = "relu") |>
  layer_dense(units = 1, activation = "sigmoid")

# compile
model |> compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy")
)

model

## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional)          | (None, 4, 4, 512)     | 14,714,688 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten)           | (None, 8192)          |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense)               | (None, 256)           |  2,097,408 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense)             | (None, 1)             |        257 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 16,812,353 (64.13 MB)
##  Trainable params: 2,097,665 (8.00 MB)
##  Non-trainable params: 14,714,688 (56.13 MB)

print(model, expand_nested = TRUE)

## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional)          | (None, 4, 4, 512)     | 14,714,688 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > input_layer            | (None, 150, 150, 3)   |          0 |   -   |
## | (InputLayer)                |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv1 (Conv2D)  | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv2 (Conv2D)  | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_pool            | (None, 75, 75, 64)    |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv1 (Conv2D)  | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv2 (Conv2D)  | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_pool            | (None, 37, 37, 128)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv1 (Conv2D)  | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv2 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv3 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_pool            | (None, 18, 18, 256)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv1 (Conv2D)  | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv2 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv3 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_pool            | (None, 9, 9, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv1 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv2 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv3 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_pool            | (None, 4, 4, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten)           | (None, 8192)          |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense)               | (None, 256)           |  2,097,408 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense)             | (None, 1)             |        257 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 16,812,353 (64.13 MB)
##  Trainable params: 2,097,665 (8.00 MB)
##  Non-trainable params: 14,714,688 (56.13 MB)



# unfreeze weights from "block5_conv1" on
unfreeze_weights(conv_base, from = "block5_conv1")

# compile again since we froze or unfroze weights
model |> compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy")
)

conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 7,079,424 (27.01 MB)
##  Non-trainable params: 7,635,264 (29.13 MB)

print(model, expand_nested = TRUE)

## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional)          | (None, 4, 4, 512)     | 14,714,688 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > input_layer            | (None, 150, 150, 3)   |          0 |   -   |
## | (InputLayer)                |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv1 (Conv2D)  | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv2 (Conv2D)  | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_pool            | (None, 75, 75, 64)    |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv1 (Conv2D)  | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv2 (Conv2D)  | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_pool            | (None, 37, 37, 128)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv1 (Conv2D)  | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv2 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv3 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_pool            | (None, 18, 18, 256)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv1 (Conv2D)  | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv2 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv3 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_pool            | (None, 9, 9, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv1 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv2 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv3 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_pool            | (None, 4, 4, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten)           | (None, 8192)          |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense)               | (None, 256)           |  2,097,408 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense)             | (None, 1)             |        257 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 16,812,353 (64.13 MB)
##  Trainable params: 9,177,089 (35.01 MB)
##  Non-trainable params: 7,635,264 (29.13 MB)


# freeze only the last 5 layers
freeze_weights(conv_base, from = -5)
conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 7,635,264 (29.13 MB)
##  Non-trainable params: 7,079,424 (27.01 MB)

# freeze only the last 5 layers, a different way
unfreeze_weights(conv_base, to = -6)
conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 7,635,264 (29.13 MB)
##  Non-trainable params: 7,079,424 (27.01 MB)


# Freeze only layers of a certain type, e.g., batch normalization layers.
# The class name is taken from a throwaway layer instance so the predicate
# below can test any layer against it.
batch_norm_layer_class_name <- class(layer_batch_normalization())[1]
is_batch_norm_layer <- function(x) {
  inherits(x, batch_norm_layer_class_name)
}

model <- application_efficientnet_b0()
model |> freeze_weights(which = is_batch_norm_layer)
# print(model)

# equivalent to: a layer stays trainable exactly when it is not a
# batch normalization layer
for (layer in model$layers) {
  layer$trainable <- !is_batch_norm_layer(layer)
}