This method retrieves the model's variables (trainable, non-trainable, optimizer, and metrics variables) and returns them as a nested named list. The names are the variable names (and the names of the enclosing model, layer, optimizer, or metric), and the values are the nested representations of the variables.
Value
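A named list with one element per requested variable group (for example trainable_variables, non_trainable_variables, optimizer_variables, and metrics_variables). Each element is a nested named list whose names follow the model, layer, optimizer, or metric names, with the variables themselves at the leaves.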
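Because the return value is an ordinary nested named list, a single variable can be reached either with $ or with recursive [[ indexing on a character vector of names. The sketch below uses a small hand-built stand-in, mock_tree (a hypothetical name), that mirrors the shape shown under Examples, with plain R numbers in place of the backend variables that get_state_tree() actually returns.

```r
# Hypothetical stand-in for a state tree: plain numbers replace the
# backend variables that get_state_tree() would actually return.
mock_tree <- list(
  trainable_variables = list(
    my_sequential = list(
      my_dense = list(kernel = matrix(-0.83), bias = 0.001)
    )
  ),
  non_trainable_variables = list()
)

# Access a leaf by name with `$` ...
kernel_a <- mock_tree$trainable_variables$my_sequential$my_dense$kernel

# ... or with recursive `[[` indexing on a character path,
# which is equivalent to chaining `[[` once per name.
path <- c("trainable_variables", "my_sequential", "my_dense", "kernel")
kernel_b <- mock_tree[[path]]

identical(kernel_a, kernel_b)
#> [1] TRUE
```

The path form is handy when the layer names are only known at run time, since the character vector can be built programmatically.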
Examples
model <- keras_model_sequential(name = "my_sequential",
                                input_shape = c(1),
                                input_name = "my_input") |>
  layer_dense(1, activation = "sigmoid", name = "my_dense")
model |> compile(optimizer = "adam", loss = "mse", metrics = c("mae"))
model |> fit(matrix(1), matrix(1), verbose = 0)
state_tree <- model |> get_state_tree()
The state_tree list returned looks like:
list(
  metrics_variables = list(
    loss = list(
      count = ...,
      total = ...
    ),
    mean_absolute_error = list(
      count = ...,
      total = ...
    )
  ),
  trainable_variables = list(
    my_sequential = list(
      my_dense = list(
        bias = ...,
        kernel = ...
      )
    )
  ),
  non_trainable_variables = list(),
  optimizer_variables = list(
    adam = list(
      iteration = ...,
      learning_rate = ...,
      my_sequential_my_dense_bias_momentum = ...,
      my_sequential_my_dense_bias_velocity = ...,
      my_sequential_my_dense_kernel_momentum = ...,
      my_sequential_my_dense_kernel_velocity = ...
    )
  )
)
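Since the tree mirrors the model's structure, it can be convenient to flatten it into slash-separated paths, one per variable. The helper flatten_state_tree() below is a hypothetical base-R sketch, not part of keras3; it is run on a stand-in list with the same shape as the structure above, using 0 as a placeholder for each real variable.

```r
# Hypothetical helper (not part of keras3): recursively walks a nested
# named list and returns a flat named list of its leaves, keyed by
# slash-separated paths. Empty sub-lists contribute no entries.
flatten_state_tree <- function(tree, prefix = NULL) {
  out <- list()
  for (name in names(tree)) {
    path <- paste(c(prefix, name), collapse = "/")
    node <- tree[[name]]
    if (is.list(node)) {
      out <- c(out, flatten_state_tree(node, path))
    } else {
      out[[path]] <- node
    }
  }
  out
}

# Stand-in with the same shape as the state_tree above; 0 is a
# placeholder for each real variable.
mock_tree <- list(
  metrics_variables = list(
    loss = list(count = 0, total = 0),
    mean_absolute_error = list(count = 0, total = 0)
  ),
  trainable_variables = list(
    my_sequential = list(my_dense = list(bias = 0, kernel = 0))
  ),
  non_trainable_variables = list(),
  optimizer_variables = list(
    adam = list(
      iteration = 0,
      learning_rate = 0,
      my_sequential_my_dense_bias_momentum = 0,
      my_sequential_my_dense_bias_velocity = 0,
      my_sequential_my_dense_kernel_momentum = 0,
      my_sequential_my_dense_kernel_velocity = 0
    )
  )
)

paths <- names(flatten_state_tree(mock_tree))
length(paths)
#> [1] 12
```

The recursion stops at any value that is not an R list, so it should also apply to a real state tree, assuming its leaves (the backend variables) are not themselves R lists.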
Printing state_tree with str() shows the variables themselves:
str(state_tree)
## List of 4
## $ trainable_variables :List of 1
## ..$ my_sequential:List of 1
## .. ..$ my_dense:List of 2
## .. .. ..$ kernel:<tf.Variable 'my_sequential/my_dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[-0.8338491]], dtype=float32)>
## .. .. ..$ bias :<tf.Variable 'my_sequential/my_dense/bias:0' shape=(1) dtype=float32, numpy=array([0.00099998], dtype=float32)>
## $ non_trainable_variables: Named list()
## $ optimizer_variables :List of 1
## ..$ adam:List of 6
## .. ..$ iteration :<tf.Variable 'adam/iteration:0' shape=() dtype=int64, numpy=1>
## .. ..$ learning_rate :<tf.Variable 'adam/learning_rate:0' shape=() dtype=float32, numpy=0.0010000000474974513>
## .. ..$ my_sequential_my_dense_kernel_momentum:<tf.Variable 'adam/my_sequential_my_dense_kernel_momentum:0' shape=(1, 1) dtype=float32, numpy=array([[-0.02943518]], dtype=float32)>
## .. ..$ my_sequential_my_dense_kernel_velocity:<tf.Variable 'adam/my_sequential_my_dense_kernel_velocity:0' shape=(1, 1) dtype=float32, numpy=array([[8.664299e-05]], dtype=float32)>
## .. ..$ my_sequential_my_dense_bias_momentum :<tf.Variable 'adam/my_sequential_my_dense_bias_momentum:0' shape=(1) dtype=float32, numpy=array([-0.02943518], dtype=float32)>
## .. ..$ my_sequential_my_dense_bias_velocity :<tf.Variable 'adam/my_sequential_my_dense_bias_velocity:0' shape=(1) dtype=float32, numpy=array([8.664299e-05], dtype=float32)>
## $ metrics_variables :List of 2
## ..$ loss :List of 2
## .. ..$ total:<tf.Variable 'loss/total:0' shape=() dtype=float32, numpy=0.4863377809524536>
## .. ..$ count:<tf.Variable 'loss/count:0' shape=() dtype=float32, numpy=1.0>
## ..$ mean_absolute_error:List of 2
## .. ..$ total:<tf.Variable 'mean_absolute_error/total:0' shape=() dtype=float32, numpy=0.6973792314529419>
## .. ..$ count:<tf.Variable 'mean_absolute_error_1/count:0' shape=() dtype=float32, numpy=1.0>
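The str() output above lists 2 trainable, 0 non-trainable, 6 optimizer, and 4 metrics variables. As a sketch, the same tally can be computed with a hypothetical count_leaves() helper (not part of keras3), shown here on a stand-in list that has only the group structure of the example, with 0 as a placeholder for each variable.

```r
# Hypothetical helper (not part of keras3): counts the non-list leaves
# of a nested list; an empty list counts as zero.
count_leaves <- function(x) {
  if (!is.list(x)) return(1L)
  sum(vapply(x, count_leaves, integer(1)))
}

# Stand-in mirroring the groups printed by str() above.
mock_tree <- list(
  trainable_variables = list(
    my_sequential = list(my_dense = list(kernel = 0, bias = 0))
  ),
  non_trainable_variables = list(),
  optimizer_variables = list(
    adam = list(
      iteration = 0, learning_rate = 0,
      my_sequential_my_dense_kernel_momentum = 0,
      my_sequential_my_dense_kernel_velocity = 0,
      my_sequential_my_dense_bias_momentum = 0,
      my_sequential_my_dense_bias_velocity = 0
    )
  ),
  metrics_variables = list(
    loss = list(total = 0, count = 0),
    mean_absolute_error = list(total = 0, count = 0)
  )
)

# One count per top-level group: 2, 0, 6 and 4 respectively.
counts <- vapply(mock_tree, count_leaves, integer(1))
sum(counts)
#> [1] 12
```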
See also
Other model functions: get_config(), get_layer(), keras_model(), keras_model_sequential(), pop_layer(), set_state_tree(), summary.keras.src.models.model.Model()