Freeze weights in a model or layer so that they are no longer trainable.
Usage
freeze_weights(object, from = NULL, to = NULL, which = NULL)
unfreeze_weights(object, from = NULL, to = NULL, which = NULL)
Value
The input `object` with frozen weights is returned, invisibly. Note that `object` is modified in place; the return value is only provided to make usage with the pipe convenient.
Note
The `from` and `to` layer arguments are both inclusive.
When applied to a model, the freeze or unfreeze is a global operation over all layers in the model (i.e. layers not within the specified range will be set to the opposite value, e.g. unfrozen for a call to freeze).
Models must be compiled again after weights are frozen or unfrozen.
Examples
# Instantiate a VGG16 model to use as a convolutional base.
# `include_top = FALSE` leaves out the classifier layers, so the last layer
# of the model is `block5_pool` (see the printed summary below).
conv_base <- application_vgg16(
weights = "imagenet",
include_top = FALSE,
input_shape = c(150, 150, 3)
)
# freeze its weights
freeze_weights(conv_base)
# Note the "Trainable" column
conv_base
## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type) | Output Shape | Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer) | (None, 150, 150, 3) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D) | (None, 150, 150, 64) | 1,792 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D) | (None, 150, 150, 64) | 36,928 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D) | (None, 75, 75, 64) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D) | (None, 75, 75, 128) | 73,856 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D) | (None, 75, 75, 128) | 147,584 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D) | (None, 37, 37, 128) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D) | (None, 37, 37, 256) | 295,168 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D) | (None, 18, 18, 256) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D) | (None, 18, 18, 512) | 1,180,160 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D) | (None, 9, 9, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D) | (None, 4, 4, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## Total params: 14,714,688 (56.13 MB)
## Trainable params: 0 (0.00 B)
## Non-trainable params: 14,714,688 (56.13 MB)
# create a composite model that includes the base + more layers;
# the input shape is taken from `conv_base$input` rather than hard-coded
model <- keras_model_sequential(input_batch_shape = shape(conv_base$input)) |>
conv_base() |>
layer_flatten() |>
layer_dense(units = 256, activation = "relu") |>
layer_dense(units = 1, activation = "sigmoid")
# compile; `conv_base` was frozen above, so in the summary printed below
# only the two dense layers are marked as trainable
model |> compile(
loss = "binary_crossentropy",
optimizer = optimizer_rmsprop(learning_rate = 2e-5),
metrics = c("accuracy")
)
model
## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type) | Output Shape | Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional) | (None, 4, 4, 512) | 14,714,688 | N |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten) | (None, 8192) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense) | (None, 256) | 2,097,408 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense) | (None, 1) | 257 | Y |
## +-----------------------------+-----------------------+------------+-------+
## Total params: 16,812,353 (64.13 MB)
## Trainable params: 2,097,665 (8.00 MB)
## Non-trainable params: 14,714,688 (56.13 MB)
print(model, expand_nested = TRUE)
## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type) | Output Shape | Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional) | (None, 4, 4, 512) | 14,714,688 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > input_layer | (None, 150, 150, 3) | 0 | - |
## | (InputLayer) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block1_conv1 (Conv2D) | (None, 150, 150, 64) | 1,792 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block1_conv2 (Conv2D) | (None, 150, 150, 64) | 36,928 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block1_pool | (None, 75, 75, 64) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block2_conv1 (Conv2D) | (None, 75, 75, 128) | 73,856 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block2_conv2 (Conv2D) | (None, 75, 75, 128) | 147,584 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block2_pool | (None, 37, 37, 128) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_conv1 (Conv2D) | (None, 37, 37, 256) | 295,168 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_conv2 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_conv3 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_pool | (None, 18, 18, 256) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_conv1 (Conv2D) | (None, 18, 18, 512) | 1,180,160 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_conv2 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_conv3 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_pool | (None, 9, 9, 512) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_conv1 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_conv2 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_conv3 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_pool | (None, 4, 4, 512) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten) | (None, 8192) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense) | (None, 256) | 2,097,408 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense) | (None, 1) | 257 | Y |
## +-----------------------------+-----------------------+------------+-------+
## Total params: 16,812,353 (64.13 MB)
## Trainable params: 2,097,665 (8.00 MB)
## Non-trainable params: 14,714,688 (56.13 MB)
# unfreeze weights from "block5_conv1" on (`from` is inclusive);
# the layers before it stay frozen, as the summary below shows
unfreeze_weights(conv_base, from = "block5_conv1")
# compile again: models must be recompiled after weights are frozen or
# unfrozen
model |> compile(
loss = "binary_crossentropy",
optimizer = optimizer_rmsprop(learning_rate = 2e-5),
metrics = c("accuracy")
)
conv_base
## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type) | Output Shape | Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer) | (None, 150, 150, 3) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D) | (None, 150, 150, 64) | 1,792 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D) | (None, 150, 150, 64) | 36,928 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D) | (None, 75, 75, 64) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D) | (None, 75, 75, 128) | 73,856 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D) | (None, 75, 75, 128) | 147,584 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D) | (None, 37, 37, 128) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D) | (None, 37, 37, 256) | 295,168 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D) | (None, 18, 18, 256) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D) | (None, 18, 18, 512) | 1,180,160 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D) | (None, 9, 9, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D) | (None, 4, 4, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## Total params: 14,714,688 (56.13 MB)
## Trainable params: 7,079,424 (27.01 MB)
## Non-trainable params: 7,635,264 (29.13 MB)
print(model, expand_nested = TRUE)
## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type) | Output Shape | Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional) | (None, 4, 4, 512) | 14,714,688 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | > input_layer | (None, 150, 150, 3) | 0 | - |
## | (InputLayer) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block1_conv1 (Conv2D) | (None, 150, 150, 64) | 1,792 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block1_conv2 (Conv2D) | (None, 150, 150, 64) | 36,928 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block1_pool | (None, 75, 75, 64) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block2_conv1 (Conv2D) | (None, 75, 75, 128) | 73,856 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block2_conv2 (Conv2D) | (None, 75, 75, 128) | 147,584 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block2_pool | (None, 37, 37, 128) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_conv1 (Conv2D) | (None, 37, 37, 256) | 295,168 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_conv2 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_conv3 (Conv2D) | (None, 37, 37, 256) | 590,080 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block3_pool | (None, 18, 18, 256) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_conv1 (Conv2D) | (None, 18, 18, 512) | 1,180,160 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_conv2 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_conv3 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | > block4_pool | (None, 9, 9, 512) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_conv1 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_conv2 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_conv3 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | > block5_pool | (None, 4, 4, 512) | 0 | - |
## | (MaxPooling2D) | | | |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten) | (None, 8192) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense) | (None, 256) | 2,097,408 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense) | (None, 1) | 257 | Y |
## +-----------------------------+-----------------------+------------+-------+
## Total params: 16,812,353 (64.13 MB)
## Trainable params: 9,177,089 (35.01 MB)
## Non-trainable params: 7,635,264 (29.13 MB)
# freeze only the last 5 layers (all other layers become trainable)
freeze_weights(conv_base, from = -5)
conv_base
## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type) | Output Shape | Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer) | (None, 150, 150, 3) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D) | (None, 150, 150, 64) | 1,792 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D) | (None, 150, 150, 64) | 36,928 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D) | (None, 75, 75, 64) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D) | (None, 75, 75, 128) | 73,856 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D) | (None, 75, 75, 128) | 147,584 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D) | (None, 37, 37, 128) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D) | (None, 37, 37, 256) | 295,168 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D) | (None, 37, 37, 256) | 590,080 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D) | (None, 37, 37, 256) | 590,080 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D) | (None, 18, 18, 256) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D) | (None, 18, 18, 512) | 1,180,160 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D) | (None, 9, 9, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D) | (None, 4, 4, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## Total params: 14,714,688 (56.13 MB)
## Trainable params: 7,635,264 (29.13 MB)
## Non-trainable params: 7,079,424 (27.01 MB)
# freeze only the last 5 layers, a different way: unfreeze every layer up to
# and including the 6th from the end, which leaves the last 5 frozen
unfreeze_weights(conv_base, to = -6)
conv_base
## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type) | Output Shape | Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer) | (None, 150, 150, 3) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D) | (None, 150, 150, 64) | 1,792 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D) | (None, 150, 150, 64) | 36,928 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D) | (None, 75, 75, 64) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D) | (None, 75, 75, 128) | 73,856 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D) | (None, 75, 75, 128) | 147,584 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D) | (None, 37, 37, 128) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D) | (None, 37, 37, 256) | 295,168 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D) | (None, 37, 37, 256) | 590,080 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D) | (None, 37, 37, 256) | 590,080 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D) | (None, 18, 18, 256) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D) | (None, 18, 18, 512) | 1,180,160 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D) | (None, 18, 18, 512) | 2,359,808 | Y |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D) | (None, 9, 9, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D) | (None, 9, 9, 512) | 2,359,808 | N |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D) | (None, 4, 4, 512) | 0 | - |
## +-----------------------------+-----------------------+------------+-------+
## Total params: 14,714,688 (56.13 MB)
## Trainable params: 7,635,264 (29.13 MB)
## Non-trainable params: 7,079,424 (27.01 MB)
# Freeze only layers of a certain type, e.g., BatchNorm layers
# First class name of a freshly created batch-normalization layer; used
# below to recognise layers of that type.
batch_norm_layer_class_name <- class(layer_batch_normalization())[1]
# Predicate: TRUE when `x` inherits from the batch-normalization class.
is_batch_norm_layer <- function(x) inherits(x, batch_norm_layer_class_name)
model <- application_efficientnet_b0()
# `which` takes a predicate: layers for which it returns TRUE are frozen and
# every other layer is set to trainable (see the equivalent loop below).
freeze_weights(model, which = is_batch_norm_layer)
# print(model)
# equivalent to:
for (layer in model$layers) {
  # A layer is trainable exactly when it is NOT a batch-norm layer, so
  # batch-norm layers end up frozen and every other layer is unfrozen.
  layer$trainable <- !is_batch_norm_layer(layer)
}