
Retrieves the model's variables (trainable, non-trainable, optimizer, and metrics variables) as a tree. The variables are returned as a nested named list: the names follow the model's layer, optimizer, metric, and variable names, and the leaves are the variables themselves.

Usage

get_state_tree(object, value_format = "backend_tensor")

Arguments

object

A Keras Model.

value_format

The kind of array to return as the leaves of the nested state tree: one of "backend_tensor" (the default), "numpy_array", or "array".

Value

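A named list with one element per variable category (trainable_variables, non_trainable_variables, optimizer_variables, and metrics_variables). Each element is itself a nested named list whose names are layer, optimizer, metric, and variable names, and whose leaves are the variable values in the format selected by value_format.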

Examples

library(keras3)

model <- keras_model_sequential(name = "my_sequential",
                                input_shape = c(1),
                                input_name = "my_input") |>
  layer_dense(1, activation = "sigmoid", name = "my_dense")

model |> compile(optimizer = "adam", loss = "mse", metrics = c("mae"))
model |> fit(matrix(1), matrix(1), verbose = 0)
state_tree <- model |> get_state_tree()

The returned state_tree list has the following structure (leaf values elided as ...):

list(
  metrics_variables = list(
    loss = list(
      count = ...,
      total = ...
    ),
    mean_absolute_error = list(
      count = ...,
      total = ...
    )
  ),
  trainable_variables = list(
    my_sequential = list(
      my_dense = list(
        bias = ...,
        kernel = ...
      )
    )
  ),
  non_trainable_variables = list(),
  optimizer_variables = list(
    adam = list(
      iteration = ...,
      learning_rate = ...,
      my_sequential_my_dense_bias_momentum = ...,
      my_sequential_my_dense_bias_velocity = ...,
      my_sequential_my_dense_kernel_momentum = ...,
      my_sequential_my_dense_kernel_velocity = ...
    )
  )
)
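The structure above can also be checked programmatically. The sketch below is a hypothetical illustration: it builds a mock tree with the same names, using NA_real_ as a stand-in for each tensor leaf, and counts the leaves in each top-level category with a small recursive helper. count_leaves() is not part of keras3, and treating "anything that is not an R list" as a leaf is an assumption.

```r
# Mock tree mirroring the structure above; NA_real_ stands in for each
# tensor leaf (illustration only, not real model state).
leaf <- NA_real_
state_tree_mock <- list(
  metrics_variables = list(
    loss = list(count = leaf, total = leaf),
    mean_absolute_error = list(count = leaf, total = leaf)
  ),
  trainable_variables = list(
    my_sequential = list(my_dense = list(bias = leaf, kernel = leaf))
  ),
  non_trainable_variables = list(),
  optimizer_variables = list(
    adam = list(
      iteration = leaf,
      learning_rate = leaf,
      my_sequential_my_dense_bias_momentum = leaf,
      my_sequential_my_dense_bias_velocity = leaf,
      my_sequential_my_dense_kernel_momentum = leaf,
      my_sequential_my_dense_kernel_velocity = leaf
    )
  )
)

# Hypothetical helper: recursively count leaves, treating any value that
# is not an R list as a single variable. An empty list counts as 0.
count_leaves <- function(node) {
  if (!is.list(node)) return(1L)
  sum(vapply(node, count_leaves, integer(1)))
}

counts <- vapply(state_tree_mock, count_leaves, integer(1))
counts  # 4, 2, 0, 6 for the categories in the order defined above
```

In this example the six Adam variables line up as iteration and learning_rate plus one momentum and one velocity variable for each of the two trainable variables (6 = 2 + 2 * 2).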

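Printing it with str() shows the actual leaves; with the TensorFlow backend used here, they are tf.Variable objects: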

str(state_tree)

## List of 4
##  $ trainable_variables    :List of 1
##   ..$ my_sequential:List of 1
##   .. ..$ my_dense:List of 2
##   .. .. ..$ kernel:<tf.Variable 'my_sequential/my_dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[-0.8338491]], dtype=float32)>
##   .. .. ..$ bias  :<tf.Variable 'my_sequential/my_dense/bias:0' shape=(1) dtype=float32, numpy=array([0.00099998], dtype=float32)>
##  $ non_trainable_variables: Named list()
##  $ optimizer_variables    :List of 1
##   ..$ adam:List of 6
##   .. ..$ iteration                             :<tf.Variable 'adam/iteration:0' shape=() dtype=int64, numpy=1>
##   .. ..$ learning_rate                         :<tf.Variable 'adam/learning_rate:0' shape=() dtype=float32, numpy=0.0010000000474974513>
##   .. ..$ my_sequential_my_dense_kernel_momentum:<tf.Variable 'adam/my_sequential_my_dense_kernel_momentum:0' shape=(1, 1) dtype=float32, numpy=array([[-0.02943518]], dtype=float32)>
##   .. ..$ my_sequential_my_dense_kernel_velocity:<tf.Variable 'adam/my_sequential_my_dense_kernel_velocity:0' shape=(1, 1) dtype=float32, numpy=array([[8.664299e-05]], dtype=float32)>
##   .. ..$ my_sequential_my_dense_bias_momentum  :<tf.Variable 'adam/my_sequential_my_dense_bias_momentum:0' shape=(1) dtype=float32, numpy=array([-0.02943518], dtype=float32)>
##   .. ..$ my_sequential_my_dense_bias_velocity  :<tf.Variable 'adam/my_sequential_my_dense_bias_velocity:0' shape=(1) dtype=float32, numpy=array([8.664299e-05], dtype=float32)>
##  $ metrics_variables      :List of 2
##   ..$ loss               :List of 2
##   .. ..$ total:<tf.Variable 'loss/total:0' shape=() dtype=float32, numpy=0.4863377809524536>
##   .. ..$ count:<tf.Variable 'loss/count:0' shape=() dtype=float32, numpy=1.0>
##   ..$ mean_absolute_error:List of 2
##   .. ..$ total:<tf.Variable 'mean_absolute_error/total:0' shape=() dtype=float32, numpy=0.6973792314529419>
##   .. ..$ count:<tf.Variable 'mean_absolute_error_1/count:0' shape=() dtype=float32, numpy=1.0>
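Because the tree is an ordinary nested named list, it can be flattened with plain R recursion when a single level is more convenient. The sketch below uses a hypothetical helper, flatten_state_tree(), which is not part of keras3: it joins the nested names with "/" and assumes that every value which is not an R list is a leaf. A small mock tree with plain numbers stands in for the tensors.

```r
# Hypothetical helper, not part of keras3: flatten a nested state tree
# into a single-level named list keyed by "/"-separated name paths.
flatten_state_tree <- function(tree, prefix = NULL) {
  out <- list()
  for (name in names(tree)) {
    path <- if (is.null(prefix)) name else paste(prefix, name, sep = "/")
    node <- tree[[name]]
    if (is.list(node)) {
      # Sub-tree: recurse; an empty list such as
      # non_trainable_variables contributes no entries.
      out <- c(out, flatten_state_tree(node, path))
    } else {
      # Leaf: store the value under its full path.
      out[[path]] <- node
    }
  }
  out
}

# Small mock tree; plain numbers stand in for the tensor leaves.
mock_tree <- list(
  trainable_variables = list(
    my_sequential = list(my_dense = list(kernel = matrix(0.5), bias = 0))
  ),
  non_trainable_variables = list(),
  metrics_variables = list(loss = list(total = 0.5, count = 1))
)

flat <- flatten_state_tree(mock_tree)
names(flat)
```

The paths are built only from the list names, so they resemble, but are not read from, the backend variable names shown above. Applied to the real state_tree from the example, it would be expected to yield one entry per variable listed by str() (12 here), assuming the tf.Variable leaves are not R lists.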