Approximates the AUC (Area under the curve) of the ROC or PR curves.
Source: R/metrics.R
The AUC (Area under the curve) of the ROC (Receiver operating characteristic; default) or PR (Precision-Recall) curve is a quality measure of binary classifiers. Unlike accuracy, and like cross-entropy losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
This class approximates AUCs using a Riemann sum. During the metric accumulation phase, predictions are accumulated within predefined buckets by value. The AUC is then computed by interpolating per-bucket averages. These buckets define the evaluated operational points.
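For instance (a minimal sketch assuming the keras R interface, e.g. keras3, is attached), the two variants are selected with the curve argument documented under Arguments below:

# Same class for both curves; only the `curve` argument differs.
roc_auc <- metric_auc(curve = "ROC")  # the default
pr_auc  <- metric_auc(curve = "PR")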
This metric creates four local variables, true_positives, true_negatives, false_positives and false_negatives, that are used to compute the AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision values. The area under the ROC-curve is therefore computed using the height of the recall values by the false positive rate, while the area under the PR-curve is computed using the height of the precision values by the recall.
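The discretization can be illustrated with a small base R sketch (illustrative only, not the Keras implementation); it reproduces the num_thresholds = 3 worked example shown under Usage below:

# Illustrative base R approximation of ROC-AUC via thresholded confusion counts
# and a trapezoidal ("interpolation") Riemann sum. Not the Keras implementation.
approx_roc_auc <- function(y_true, y_pred, num_thresholds = 200) {
  eps <- 1e-7
  # linearly spaced interior thresholds, plus -eps and 1 + eps endpoints
  thresholds <- c(-eps, seq_len(num_thresholds - 2) / (num_thresholds - 1), 1 + eps)
  tp <- sapply(thresholds, function(t) sum(y_pred >  t & y_true == 1))
  fp <- sapply(thresholds, function(t) sum(y_pred >  t & y_true == 0))
  fn <- sapply(thresholds, function(t) sum(y_pred <= t & y_true == 1))
  tn <- sapply(thresholds, function(t) sum(y_pred <= t & y_true == 0))
  recall  <- tp / (tp + fn)   # true positive rate at each threshold
  fp_rate <- fp / (fp + tn)   # false positive rate at each threshold
  n <- length(thresholds)
  heights <- (recall[-n] + recall[-1]) / 2   # mid-point of adjacent recalls
  widths  <- fp_rate[-n] - fp_rate[-1]       # drop in false positive rate
  sum(heights * widths)
}
approx_roc_auc(c(0, 0, 1, 1), c(0, 0.5, 0.3, 0.9), num_thresholds = 3)  # 0.75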
This value is ultimately returned as auc, an idempotent operation that computes the area under a discretized curve of precision versus recall values (computed using the aforementioned variables). The num_thresholds variable controls the degree of discretization, with larger numbers of thresholds more closely approximating the true AUC. The quality of the approximation may vary dramatically depending on num_thresholds. The thresholds parameter can be used to manually specify thresholds which split the predictions more evenly.
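As a hypothetical illustration (the score values below are made up), custom thresholds can concentrate the discretization where the predictions actually fall:

# Scores clustered near 0.9: a coarse uniform grid places few thresholds there,
# so supplying custom split points discretizes the predictions more evenly.
m <- metric_auc(thresholds = c(0.80, 0.85, 0.90, 0.95))  # num_thresholds is ignored
m$update_state(c(0, 1, 0, 1, 1),
               c(0.84, 0.91, 0.88, 0.96, 0.86))
m$result()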
For the best approximation of the real AUC, predictions should be distributed approximately uniformly in the range [0, 1] (if from_logits=FALSE). The quality of the AUC approximation may be poor if this is not the case. Setting summation_method to 'minoring' or 'majoring' can help quantify the error in the approximation by providing lower or upper bound estimates of the AUC.
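For example (a sketch with illustrative data), evaluating the metric under all three summation methods brackets the interpolated estimate between the 'minoring' lower bound and the 'majoring' upper bound:

y_true <- c(0, 0, 1, 1)
y_pred <- c(0.1, 0.4, 0.35, 0.8)
for (method in c("minoring", "interpolation", "majoring")) {
  m <- metric_auc(num_thresholds = 5, summation_method = method)
  m$update_state(y_true, y_pred)
  print(m$result())  # lower bound, interpolated estimate, upper bound
}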
If sample_weight is NULL, weights default to 1. Use sample_weight of 0 to mask values (see the standalone usage example below).
Usage
metric_auc(
  ...,
  num_thresholds = 200L,
  curve = "ROC",
  summation_method = "interpolation",
  name = NULL,
  dtype = NULL,
  thresholds = NULL,
  multi_label = FALSE,
  num_labels = NULL,
  label_weights = NULL,
  from_logits = FALSE
)
Arguments
- ...
  For forward/backward compatibility.
- num_thresholds
  (Optional) The number of thresholds to use when discretizing the ROC curve. Values must be > 1. Defaults to 200.
- curve
  (Optional) Specifies the name of the curve to be computed, 'ROC' (default) or 'PR' for the Precision-Recall curve.
- summation_method
  (Optional) Specifies the Riemann summation method used. 'interpolation' (default) applies the mid-point summation scheme for ROC. For PR-AUC, it interpolates (true/false) positives but not the ratio that is precision (see Davis & Goadrich 2006 for details); 'minoring' applies left summation for increasing intervals and right summation for decreasing intervals; 'majoring' does the opposite.
- name
  (Optional) String name of the metric instance.
- dtype
  (Optional) Data type of the metric result.
- thresholds
  (Optional) A list of floating point values to use as the thresholds for discretizing the curve. If set, the num_thresholds parameter is ignored. Values should be in [0, 1]. Endpoint thresholds equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be automatically included with these to correctly handle predictions equal to exactly 0 or 1.
- multi_label
  Boolean indicating whether multilabel data should be treated as such, wherein AUC is computed separately for each label and then averaged across labels, or (when FALSE) whether the data should be flattened into a single label before AUC computation. In the latter case, when multilabel data is passed to AUC, each label-prediction pair is treated as an individual data point. Should be set to FALSE for multi-class data. (A multilabel sketch follows this argument list.)
- num_labels
  (Optional) The number of labels, used when multi_label is TRUE. If num_labels is not specified, then state variables get created on the first call to update_state.
- label_weights
  (Optional) List, array, or tensor of non-negative weights used to compute AUCs for multilabel data. When multi_label is TRUE, the weights are applied to the individual label AUCs when they are averaged to produce the multi-label AUC. When it is FALSE, they are used to weight the individual label predictions in computing the confusion matrix on the flattened data. Note that this is unlike class_weights in that class_weights weights the example depending on the value of its label, whereas label_weights depends only on the index of that label before flattening; therefore label_weights should not be used for multi-class data.
- from_logits
  Boolean indicating whether the predictions (y_pred in update_state) are probabilities or sigmoid logits. As a rule of thumb, when using a keras loss, the from_logits constructor argument of the loss should match the AUC from_logits constructor argument.
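A minimal multilabel sketch (the 3-label matrices below are illustrative, not from the package documentation):

# With multi_label = TRUE, an AUC is computed per label and then averaged.
m <- metric_auc(multi_label = TRUE, num_labels = 3)
y_true <- rbind(c(1, 0, 1),
                c(0, 1, 0),
                c(1, 1, 0),
                c(0, 0, 1))
y_pred <- rbind(c(0.9, 0.2, 0.8),
                c(0.1, 0.7, 0.3),
                c(0.8, 0.6, 0.2),
                c(0.3, 0.1, 0.7))
m$update_state(y_true, y_pred)
m$result()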
Value
A Metric instance is returned. The Metric instance can be passed directly to compile(metrics = ), or used as a standalone object. See ?Metric for example usage.
Usage
Standalone usage:
m <- metric_auc(num_thresholds = 3)
m$update_state(c(0, 0, 1, 1),
               c(0, 0.5, 0.3, 0.9))
# threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
# auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
# = 0.75
m$result()
m$reset_state()
m$update_state(c(0, 0, 1, 1),
               c(0, 0.5, 0.3, 0.9),
               sample_weight = c(1, 0, 0, 1))
m$result()
Usage with compile() API:
# Reports the AUC of a model outputting a probability.
model |> compile(
optimizer = 'sgd',
loss = loss_binary_crossentropy(),
metrics = list(metric_auc())
)
# Reports the AUC of a model outputting a logit.
model |> compile(
optimizer = 'sgd',
loss = loss_binary_crossentropy(from_logits = TRUE),
metrics = list(metric_auc(from_logits = TRUE))
)
See also
Other confusion metrics: metric_false_negatives(), metric_false_positives(), metric_precision(), metric_precision_at_recall(), metric_recall(), metric_recall_at_precision(), metric_sensitivity_at_specificity(), metric_specificity_at_sensitivity(), metric_true_negatives(), metric_true_positives()
Other metrics: Metric(), custom_metric(), metric_binary_accuracy(), metric_binary_crossentropy(), metric_binary_focal_crossentropy(), metric_binary_iou(), metric_categorical_accuracy(), metric_categorical_crossentropy(), metric_categorical_focal_crossentropy(), metric_categorical_hinge(), metric_cosine_similarity(), metric_f1_score(), metric_false_negatives(), metric_false_positives(), metric_fbeta_score(), metric_hinge(), metric_huber(), metric_iou(), metric_kl_divergence(), metric_log_cosh(), metric_log_cosh_error(), metric_mean(), metric_mean_absolute_error(), metric_mean_absolute_percentage_error(), metric_mean_iou(), metric_mean_squared_error(), metric_mean_squared_logarithmic_error(), metric_mean_wrapper(), metric_one_hot_iou(), metric_one_hot_mean_iou(), metric_poisson(), metric_precision(), metric_precision_at_recall(), metric_r2_score(), metric_recall(), metric_recall_at_precision(), metric_root_mean_squared_error(), metric_sensitivity_at_specificity(), metric_sparse_categorical_accuracy(), metric_sparse_categorical_crossentropy(), metric_sparse_top_k_categorical_accuracy(), metric_specificity_at_sensitivity(), metric_squared_hinge(), metric_sum(), metric_top_k_categorical_accuracy(), metric_true_negatives(), metric_true_positives()