Configure a model for training.
Usage
# S3 method for class 'keras.src.models.model.Model'
compile(
  object,
  optimizer = "rmsprop",
  loss = NULL,
  metrics = NULL,
  ...,
  loss_weights = NULL,
  weighted_metrics = NULL,
  run_eagerly = FALSE,
  steps_per_execution = 1L,
  jit_compile = "auto",
  auto_scale_loss = TRUE
)
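A minimal invocation can rely on built-in string names for the optimizer, loss, and metrics (a sketch; assumes `model` is an existing keras3 model, e.g. from keras_model_sequential(), and that keras3 is attached):

```r
# Strings resolve to the corresponding built-in objects;
# "rmsprop" is the default optimizer.
model |> compile(
  optimizer = "rmsprop",
  loss = "sparse_categorical_crossentropy",
  metrics = "accuracy"
)
```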
Arguments
- object
Keras model object
- optimizer
String (name of optimizer) or optimizer instance. See the optimizer_* family of functions.
- loss
Loss function. May be:
  a string (name of a built-in loss function),
  a custom function, or
  a Loss instance (returned by the loss_* family of functions).
A loss function is any callable with the signature loss = fn(y_true, y_pred), where y_true are the ground truth values and y_pred are the model's predictions. y_true should have shape (batch_size, d1, .. dN) (except in the case of sparse loss functions such as sparse categorical crossentropy, which expects integer arrays of shape (batch_size, d1, .. dN-1)). y_pred should have shape (batch_size, d1, .. dN). The loss function should return a float tensor.
- metrics
List of metrics to be evaluated by the model during training and testing. Each of these can be:
  a string (name of a built-in function),
  a function, optionally with a "name" attribute, or
  a Metric() instance. See the metric_* family of functions.
Typically you will use metrics = c('accuracy'). A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a named list, such as metrics = list(a = 'accuracy', b = c('accuracy', 'mse')). You can also pass a list to specify a metric or a list of metrics for each output, such as metrics = list(c('accuracy'), c('accuracy', 'mse')) or metrics = list('accuracy', c('accuracy', 'mse')). When you pass the strings 'accuracy' or 'acc', we convert this to one of metric_binary_accuracy(), metric_categorical_accuracy(), or metric_sparse_categorical_accuracy() based on the shapes of the targets and of the model output. A similar conversion is done for the strings "crossentropy" and "ce" as well. The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, you can specify your metrics via the weighted_metrics argument instead.
If providing an anonymous R function, you can customize the printed name during training by assigning attr(<fn>, "name") <- "my_custom_metric_name", or by calling custom_metric("my_custom_metric_name", <fn>).
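The two naming options above can be sketched as follows (assumes keras3 is attached and `model` is a defined keras3 model; the metric itself, a mean absolute error, is illustrative):

```r
library(keras3)

# Option 1: a plain R function with a "name" attribute.
# op_mean() and op_abs() are keras3 tensor operations.
my_mae <- function(y_true, y_pred) {
  op_mean(op_abs(y_true - y_pred))
}
attr(my_mae, "name") <- "my_mae"

# Option 2: wrap the function with custom_metric().
my_mae2 <- custom_metric("my_mae", function(y_true, y_pred) {
  op_mean(op_abs(y_true - y_pred))
})

# Either form can then be passed to compile():
model |> compile(optimizer = "adam", loss = "mse", metrics = list(my_mae))
```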
- ...
Additional arguments passed on to the compile() model method.
- loss_weights
Optional list (named or unnamed) specifying scalar coefficients (R numerics) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If an unnamed list, it is expected to have a 1:1 mapping to the model's outputs. If a named list, it is expected to map output names (strings) to scalar coefficients.
- weighted_metrics
List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.
- run_eagerly
Bool. If TRUE, this model's forward pass will never be compiled. It is recommended to leave this as FALSE when training (for best performance), and to set it to TRUE when debugging.
- steps_per_execution
Int. The number of batches to run during a single compiled function call. Running multiple batches inside a single compiled function call can greatly improve performance on TPUs or on small models with a large R/Python overhead. At most, one full epoch will be run each execution. If a number larger than the size of the epoch is passed, the execution will be truncated to the size of the epoch. Note that if steps_per_execution is set to N, Callback$on_batch_begin and Callback$on_batch_end methods will only be called every N batches (i.e. before/after each compiled function execution). Not supported with the PyTorch backend.
- jit_compile
Bool or "auto". Whether to use XLA compilation when compiling a model. For the jax and tensorflow backends, jit_compile = "auto" enables XLA compilation if the model supports it, and is disabled otherwise. For the torch backend, "auto" will default to eager execution, and jit_compile = TRUE will run with torch.compile with the "inductor" backend.
- auto_scale_loss
Bool. If TRUE and the model dtype policy is "mixed_float16", the passed optimizer will be automatically wrapped in a LossScaleOptimizer, which will dynamically scale the loss to prevent underflow.
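For a multi-output model, the loss, metrics, and loss_weights arguments can all be given as named lists keyed by output name. A sketch (assumes `model` is a keras3 model with two outputs; the output names "a" and "b" are hypothetical):

```r
# Per-output losses and metrics, plus per-output loss weights.
# Total minimized loss = 1.0 * loss_a + 0.5 * loss_b.
model |> compile(
  optimizer = "adam",
  loss = list(a = "binary_crossentropy", b = "mse"),
  metrics = list(a = "accuracy", b = c("mae", "mse")),
  loss_weights = list(a = 1.0, b = 0.5)
)
```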
Value
This is called primarily for the side effect of modifying object in-place. The first argument object is also returned, invisibly, to enable usage with the pipe.
Examples
model |> compile(
  optimizer = optimizer_adam(learning_rate = 1e-3),
  loss = loss_binary_crossentropy(),
  metrics = c(metric_binary_accuracy(),
              metric_false_negatives())
)
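A debugging-oriented configuration might disable graph compilation and XLA, and route a metric through weighted_metrics so that sample weights are honored (a sketch; assumes `model` is a defined keras3 model):

```r
model |> compile(
  optimizer = optimizer_adam(),
  loss = loss_binary_crossentropy(),
  metrics = list(metric_binary_accuracy()),
  weighted_metrics = list(metric_binary_accuracy()),  # weighted by sample_weight
  run_eagerly = TRUE,    # run the forward pass eagerly while debugging
  jit_compile = FALSE    # disable XLA compilation
)
```

Note that run_eagerly = TRUE trades performance for inspectability; switch it back to FALSE for real training runs.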