Approximates the AUC (Area under the curve) of the ROC or PR curves.
Source: R/metrics.R (metric_auc.Rd)
The AUC (Area under the curve) of the ROC (Receiver operating characteristic; default) or PR (Precision Recall) curve is a quality measure of binary classifiers. Unlike accuracy, and like cross-entropy losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
This class approximates AUCs using a Riemann sum. During the metric accumulation phase, predictions are accumulated within predefined buckets by value. The AUC is then computed by interpolating per-bucket averages. These buckets define the evaluated operational points.
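As an illustration of the bucketing idea, the sketch below uses base R only (it does not call keras3) and assumes that a prediction counts as positive at a threshold when it is strictly greater than that threshold. The helper name bucket_counts is hypothetical. Each bucket collects the predictions that fall between two consecutive thresholds, and reversed cumulative sums over the buckets then give the true and false positive counts at every threshold.

```r
# Hypothetical helper (not part of keras3): accumulate predictions into
# buckets delimited by sorted thresholds, then derive per-threshold counts.
bucket_counts <- function(y_true, y_pred, thresholds) {
  n <- length(thresholds)
  # Bucket i holds predictions in (thresholds[i], thresholds[i + 1]];
  # the last bucket is open-ended. Predictions at or below thresholds[1]
  # get index 0, which tabulate() ignores.
  idx <- findInterval(y_pred, thresholds, left.open = TRUE)
  pos_per_bucket <- tabulate(idx[y_true == 1], nbins = n)
  neg_per_bucket <- tabulate(idx[y_true == 0], nbins = n)
  # A prediction is > thresholds[k] exactly when its bucket index is >= k,
  # so reversed cumulative sums give the counts above each threshold.
  list(
    tp = rev(cumsum(rev(pos_per_bucket))),
    fp = rev(cumsum(rev(neg_per_bucket)))
  )
}

thresholds <- c(0 - 1e-7, 0.5, 1 + 1e-7)
counts <- bucket_counts(c(0, 0, 1, 1), c(0, 0.5, 0.3, 0.9), thresholds)
counts$tp  # 2 1 0
counts$fp  # 2 0 0
```

Bucketing means each prediction is placed once, instead of being compared against every threshold; the counts agree with the per-threshold values quoted in the standalone example further down this page.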
This metric creates four local variables, true_positives,
true_negatives, false_positives and false_negatives, that are used to
compute the AUC. To discretize the AUC curve, a linearly spaced set of
thresholds is used to compute pairs of rate values (recall and false positive
rate for ROC, precision and recall for PR). The area under the ROC curve is
therefore computed using the height of the recall values by the false positive
rate, while the area under the PR curve is computed using the height of the
precision values by the recall.
This value is ultimately returned as auc, an idempotent operation that
computes the area under the discretized curve (computed using the
aforementioned variables). The num_thresholds variable controls the degree
of discretization, with larger numbers of thresholds more closely
approximating the true AUC. The quality of the approximation may vary
dramatically depending on num_thresholds. The thresholds parameter can be
used to manually specify thresholds that split the predictions more evenly.
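The threshold-based computation described above can be sketched in base R as follows. This is not the keras3 implementation; the helper names are hypothetical, and the sketch assumes that num_thresholds - 2 evenly spaced interior thresholds are combined with endpoints at 0 - 1e-7 and 1 + 1e-7 (the values quoted in the standalone example), that a prediction is positive when it is strictly greater than a threshold, and that the ROC height between neighboring thresholds is the mid-point of their recall values.

```r
# Hypothetical helper: linearly spaced thresholds plus endpoints just
# outside [0, 1] so that predictions of exactly 0 or 1 are handled.
make_thresholds <- function(num_thresholds, eps = 1e-7) {
  c(-eps, seq_len(num_thresholds - 2) / (num_thresholds - 1), 1 + eps)
}

# Hypothetical helper: confusion counts at every threshold.
confusion_counts <- function(y_true, y_pred, thresholds) {
  is_pos <- outer(y_pred, thresholds, ">")  # rows: examples, cols: thresholds
  list(
    tp = colSums(is_pos & (y_true == 1)),
    fp = colSums(is_pos & (y_true == 0)),
    fn = colSums((!is_pos) & (y_true == 1)),
    tn = colSums((!is_pos) & (y_true == 0))
  )
}

# Hypothetical helper: ROC-AUC as a Riemann sum where the height is recall
# (true positive rate) and the width is the change in false positive rate.
roc_auc_from_counts <- function(cm) {
  tpr <- cm$tp / pmax(cm$tp + cm$fn, 1)  # 0 when there are no positives
  fpr <- cm$fp / pmax(cm$fp + cm$tn, 1)  # 0 when there are no negatives
  k <- length(tpr)
  sum((fpr[-k] - fpr[-1]) * (tpr[-k] + tpr[-1]) / 2)
}

cm <- confusion_counts(c(0, 0, 1, 1), c(0, 0.5, 0.3, 0.9), make_thresholds(3))
roc_auc_from_counts(cm)  # 0.75
```

With num_thresholds = 3 this reproduces the counts and the 0.75 result of the worked standalone example; raising num_thresholds only adds interior thresholds, which refines the discretization.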
For the best approximation of the real AUC, predictions should be
distributed approximately uniformly in the range [0, 1] (if
from_logits=FALSE). The quality of the AUC approximation may be poor if
this is not the case. Setting summation_method to 'minoring' or 'majoring'
can help quantify the error in the approximation by providing lower- or
upper-bound estimates of the AUC.
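To see how 'minoring' and 'majoring' bracket the default estimate, here is a small base R sketch; the function riemann_auc is hypothetical and not a keras3 export. It takes the false positive rates and true positive rates at consecutive thresholds and uses, for each interval, the mid-point, the smaller, or the larger of the two neighboring heights. It covers ROC only; for PR-AUC, the default 'interpolation' follows Davis & Goadrich (2006), which this sketch does not reproduce.

```r
# Hypothetical helper: Riemann sum over consecutive operating points.
# x: false positive rates, y: true positive rates, ordered by increasing
# threshold (so x is non-increasing and every width below is >= 0).
riemann_auc <- function(x, y, method = c("interpolation", "minoring", "majoring")) {
  method <- match.arg(method)
  k <- length(y)
  y_here <- y[-k]  # height at threshold i
  y_next <- y[-1]  # height at threshold i + 1
  heights <- switch(method,
    interpolation = (y_here + y_next) / 2,  # mid-point
    minoring = pmin(y_here, y_next),        # lower bound
    majoring = pmax(y_here, y_next)         # upper bound
  )
  sum((x[-k] - x[-1]) * heights)
}

# Rates from the standalone example on this page.
fpr <- c(1, 0, 0)
tpr <- c(1, 0.5, 0)
riemann_auc(fpr, tpr, "minoring")       # 0.5
riemann_auc(fpr, tpr, "interpolation")  # 0.75
riemann_auc(fpr, tpr, "majoring")       # 1
```

Because the widths are never negative and min <= mid-point <= max for each interval, the minoring and majoring sums always enclose the interpolated value; a wide gap between them signals that the thresholds are too coarse for the given predictions.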
If sample_weight is NULL, weights default to 1.
Use sample_weight of 0 to mask values.
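The effect of sample_weight on the confusion variables can be sketched in base R, assuming each example adds its weight (rather than 1) to whichever of the four counts it falls into. The helper weighted_confusion is hypothetical. With a weight of 0, an example contributes nothing, which is the same as dropping it.

```r
# Hypothetical helper: weighted confusion counts at a single threshold.
weighted_confusion <- function(y_true, y_pred, threshold, sample_weight = NULL) {
  if (is.null(sample_weight)) sample_weight <- rep(1, length(y_true))  # NULL -> 1
  is_pos <- y_pred > threshold
  c(
    tp = sum(sample_weight[is_pos & (y_true == 1)]),
    fp = sum(sample_weight[is_pos & (y_true == 0)]),
    fn = sum(sample_weight[(!is_pos) & (y_true == 1)]),
    tn = sum(sample_weight[(!is_pos) & (y_true == 0)])
  )
}

y_true <- c(0, 0, 1, 1)
y_pred <- c(0, 0.5, 0.3, 0.9)
w <- c(1, 0, 0, 1)
weighted_confusion(y_true, y_pred, 0.5, w)  # tp = 1, fp = 0, fn = 0, tn = 1

# Masking with weight 0 gives the same counts as removing those examples.
keep <- w > 0
weighted_confusion(y_true[keep], y_pred[keep], 0.5)
```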
Usage
metric_auc(
...,
num_thresholds = 200L,
curve = "ROC",
summation_method = "interpolation",
name = NULL,
dtype = NULL,
thresholds = NULL,
multi_label = FALSE,
num_labels = NULL,
label_weights = NULL,
from_logits = FALSE
)
Arguments
- ...
For forward/backward compatibility.
- num_thresholds
(Optional) The number of thresholds to use when discretizing the ROC curve. Values must be > 1. Defaults to 200.
- curve
(Optional) Specifies the name of the curve to be computed, 'ROC' (default) or 'PR' for the Precision-Recall curve.
- summation_method
(Optional) Specifies the Riemann summation method used. 'interpolation' (default) applies a mid-point summation scheme for ROC; for PR-AUC, it interpolates (true/false) positives but not the ratio that is precision (see Davis & Goadrich 2006 for details). 'minoring' applies left summation for increasing intervals and right summation for decreasing intervals; 'majoring' does the opposite.
- name
(Optional) string name of the metric instance.
- dtype
(Optional) data type of the metric result.
- thresholds
(Optional) A list of floating point values to use as the thresholds for discretizing the curve. If set, the num_thresholds parameter is ignored. Values should be in [0, 1]. Endpoint thresholds equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be automatically included with these to correctly handle predictions equal to exactly 0 or 1.
- multi_label
Boolean indicating whether multilabel data should be treated as such, wherein AUC is computed separately for each label and then averaged across labels, or (when FALSE) whether the data should be flattened into a single label before AUC computation. In the latter case, when multilabel data is passed to AUC, each label-prediction pair is treated as an individual data point. Should be set to FALSE for multi-class data.
- num_labels
(Optional) The number of labels, used when multi_label is TRUE. If num_labels is not specified, then state variables get created on the first call to update_state.
- label_weights
(Optional) list, array, or tensor of non-negative weights used to compute AUCs for multilabel data. When multi_label is TRUE, the weights are applied to the individual label AUCs when they are averaged to produce the multi-label AUC. When multi_label is FALSE, they are used to weight the individual label predictions in computing the confusion matrix on the flattened data. Note that this is unlike class_weights in that class_weights weights the example depending on the value of its label, whereas label_weights depends only on the index of that label before flattening; therefore label_weights should not be used for multi-class data.
- from_logits
Boolean indicating whether the predictions (y_pred in update_state) are probabilities or sigmoid logits. As a rule of thumb, when using a keras loss, the from_logits constructor argument of the loss should match the AUC from_logits constructor argument.
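The difference between multi_label = TRUE and FALSE, the role of label_weights when averaging, and the from_logits conversion can be illustrated with a base R sketch. The helpers roc_auc, multi_label_auc and flattened_auc are hypothetical; the sketch assumes a mid-point ROC-AUC with 200 thresholds, that every label column contains both classes, that label-weighted averaging is sum(auc * w) / sum(w), and that from_logits = TRUE means applying the sigmoid (plogis in base R) before thresholding. It does not cover label_weights in the flattened case.

```r
# Hypothetical helper: ROC-AUC of one label with mid-point heights;
# assumes y_true contains both 0s and 1s.
roc_auc <- function(y_true, y_pred, num_thresholds = 200) {
  thresholds <- c(-1e-7, seq_len(num_thresholds - 2) / (num_thresholds - 1), 1 + 1e-7)
  is_pos <- outer(y_pred, thresholds, ">")
  tpr <- colSums(is_pos & (y_true == 1)) / sum(y_true == 1)
  fpr <- colSums(is_pos & (y_true == 0)) / sum(y_true == 0)
  k <- length(thresholds)
  sum((fpr[-k] - fpr[-1]) * (tpr[-k] + tpr[-1]) / 2)
}

# multi_label = TRUE: one AUC per label column, then a (weighted) average.
multi_label_auc <- function(y_true, y_pred, label_weights = NULL, from_logits = FALSE) {
  if (from_logits) y_pred <- plogis(y_pred)  # sigmoid
  per_label <- vapply(seq_len(ncol(y_true)),
                      function(j) roc_auc(y_true[, j], y_pred[, j]),
                      numeric(1))
  if (is.null(label_weights)) {
    mean(per_label)
  } else {
    sum(per_label * label_weights) / sum(label_weights)
  }
}

# multi_label = FALSE: every label-prediction pair becomes one data point.
flattened_auc <- function(y_true, y_pred, from_logits = FALSE) {
  if (from_logits) y_pred <- plogis(y_pred)
  roc_auc(as.vector(y_true), as.vector(y_pred))
}

y_true <- cbind(c(0, 0, 1, 1), c(1, 1, 0, 0))
y_pred <- cbind(c(0.1, 0.2, 0.8, 0.9), c(0.3, 0.4, 0.6, 0.7))
multi_label_auc(y_true, y_pred)                           # 0.5 (label AUCs 1 and 0)
multi_label_auc(y_true, y_pred, label_weights = c(1, 3))  # 0.25
flattened_auc(y_true, y_pred)                             # 0.75
```

The same predictions give different answers in the two modes: per-label averaging sees one perfectly ranked label and one fully inverted label, while flattening pools all pairs and lets high scores of the first label outrank low scores of the second.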
Value
a Metric instance is returned. The Metric instance can be passed
directly to compile(metrics = ), or used as a standalone object. See
?Metric for example usage.
Usage
Standalone usage:
m <- metric_auc(num_thresholds = 3)
m$update_state(c(0, 0, 1, 1),
c(0, 0.5, 0.3, 0.9))
# threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
# auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
# = 0.75
m$result()  # 0.75
m$reset_state()  # clear the accumulated confusion variables
m$update_state(c(0, 0, 1, 1),
c(0, 0.5, 0.3, 0.9),
sample_weight = c(1, 0, 0, 1))
m$result()  # 1
Usage with compile() API:
# Reports the AUC of a model outputting a probability.
model |> compile(
optimizer = 'sgd',
loss = loss_binary_crossentropy(),
metrics = list(metric_auc())
)
# Reports the AUC of a model outputting a logit.
model |> compile(
optimizer = 'sgd',
loss = loss_binary_crossentropy(from_logits = TRUE),
metrics = list(metric_auc(from_logits = TRUE))
)
See also
Other confusion metrics: metric_false_negatives() metric_false_positives() metric_precision() metric_precision_at_recall() metric_recall() metric_recall_at_precision() metric_sensitivity_at_specificity() metric_specificity_at_sensitivity() metric_true_negatives() metric_true_positives()
Other metrics: Metric() custom_metric() metric_binary_accuracy() metric_binary_crossentropy() metric_binary_focal_crossentropy() metric_binary_iou() metric_categorical_accuracy() metric_categorical_crossentropy() metric_categorical_focal_crossentropy() metric_categorical_hinge() metric_concordance_correlation() metric_cosine_similarity() metric_f1_score() metric_false_negatives() metric_false_positives() metric_fbeta_score() metric_hinge() metric_huber() metric_iou() metric_kl_divergence() metric_log_cosh() metric_log_cosh_error() metric_mean() metric_mean_absolute_error() metric_mean_absolute_percentage_error() metric_mean_iou() metric_mean_squared_error() metric_mean_squared_logarithmic_error() metric_mean_wrapper() metric_one_hot_iou() metric_one_hot_mean_iou() metric_pearson_correlation() metric_poisson() metric_precision() metric_precision_at_recall() metric_r2_score() metric_recall() metric_recall_at_precision() metric_root_mean_squared_error() metric_sensitivity_at_specificity() metric_sparse_categorical_accuracy() metric_sparse_categorical_crossentropy() metric_sparse_top_k_categorical_accuracy() metric_specificity_at_sensitivity() metric_squared_hinge() metric_sum() metric_top_k_categorical_accuracy() metric_true_negatives() metric_true_positives()