Use this to define a custom loss class. Note that in most cases you do not need
to subclass Loss to define a custom loss: you can also pass a bare R function,
or a named R function defined with custom_metric(), as the loss function to
compile().
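For example, a bare R function can be passed directly as the loss, with no class
definition at all; a minimal sketch (the model definition and optimizer below are
illustrative only, not part of this page):

model <- keras_model_sequential(input_shape = 10) |> layer_dense(1)
model |> compile(
  optimizer = "adam",
  loss = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1)
  }
)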
Usage
Loss(
  classname,
  call = NULL,
  ...,
  public = list(),
  private = list(),
  inherit = NULL,
  parent_env = parent.frame()
)
Arguments
- classname
  String, the name of the custom class (conventionally, CamelCase).
- call
  function(y_true, y_pred)
  Method to be implemented by subclasses: a function that contains the logic for
  the loss calculation using y_true and y_pred.
- ..., public
  Additional methods or public members of the custom class (see the sketch after
  this argument list).
- private
  Named list of R objects (typically, functions) to include in instance private
  environments. private methods have the same symbols in scope as public methods
  (see section "Symbols in scope"). Each instance has its own private environment.
  Any objects in private are invisible to the Keras framework and the Python
  runtime.
- inherit
  What the custom class will subclass. By default, the base Keras Loss class.
- parent_env
  The R environment that all class methods will have as a grandparent.
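As a sketch of the public argument, a loss with a configurable parameter can
supply its own initialize method alongside call; the scale parameter and class
name below are made up for illustration:

loss_scaled_mse <- Loss(
  classname = "ScaledMeanSquaredError",
  call = function(y_true, y_pred) {
    self$scale * op_mean(op_square(y_pred - y_true), axis = -1)
  },
  public = list(
    initialize = function(scale = 1, ...) {
      super$initialize(...)  # forward name, reduction, dtype to the base class
      self$scale <- scale
    }
  )
)

# loss_scaled_mse(scale = 0.5) then behaves like a half-weighted MSE.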
Details
Example subclass implementation:
loss_custom_mse <- Loss(
  classname = "CustomMeanSquaredError",
  call = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1)
  }
)
# Usage in compile()
model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)
model |> compile(loss = loss_custom_mse())
# Standalone usage
mse <- loss_custom_mse(name = "my_custom_mse_instance")
y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
(loss <- mse(y_true, y_pred))
Methods defined by the base Loss class:

- initialize(name=NULL, reduction="sum_over_batch_size", dtype=NULL)
  Args:
  - name: Optional name for the loss instance.
  - reduction: Type of reduction to apply to the loss. In almost all cases this
    should be "sum_over_batch_size". Supported options are "sum",
    "sum_over_batch_size" or NULL.
  - dtype: The dtype of the loss's computations. Defaults to NULL, which means
    using config_floatx(). config_floatx() is "float32" unless set to a different
    value (via config_set_floatx()). If a keras$DTypePolicy is provided, then its
    compute_dtype will be utilized.

- __call__(y_true, y_pred, sample_weight=NULL)
  Call the loss instance as a function, optionally with sample_weight.
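Continuing the standalone example above, the loss instance can also be called
with per-sample weights; a minimal sketch (the weight values are arbitrary):

sw <- op_array(c(1, 1, 0.5, 0.5))        # one weight per sample in the batch
mse(y_true, y_pred, sample_weight = sw)  # weights are applied before reduction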
Symbols in scope
All R function custom methods (public and private) will have the following symbols in scope:
- self: The custom class instance.
- super: The custom class superclass.
- private: An R environment specific to the class instance. Any objects assigned
  here are invisible to the Keras framework.
- __class__ and as.symbol(classname): The custom class type object.
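For example, a private helper is visible inside call but hidden from Keras and
the Python runtime; a minimal sketch (the clipping helper and bound are made up
for illustration):

loss_clipped_mse <- Loss(
  classname = "ClippedMeanSquaredError",
  call = function(y_true, y_pred) {
    # private$clip is in scope here, but invisible to the Python runtime
    private$clip(op_mean(op_square(y_pred - y_true), axis = -1))
  },
  private = list(
    clip = function(x) op_clip(x, 0, 1e6)
  )
)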
See also
Other losses: loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()