
Formula:

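$$
r = \frac{\sum_i (y_i - \bar{y})(\hat{y}_i - \bar{\hat{y}})}{\sqrt{\sum_i (y_i - \bar{y})^2}\,\sqrt{\sum_i (\hat{y}_i - \bar{\hat{y}})^2}}
$$

where y is y_true and ŷ is y_pred for a single sample, the sums and means run along `axis`, and the metric reports the mean of r over all samples.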
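The formula can be checked directly in base R. The sketch below uses a hypothetical helper, `pcc_rowwise()`, which is not part of keras3. It assumes a 2-D matrix with one sample per row (matching the default `axis = -1`) and follows the formula term by term: center each row, divide the sum of products of deviations by the product of the L2 norms, then average the per-row coefficients.

```r
# Hypothetical helper (not part of keras3): row-wise Pearson r, then averaged.
pcc_rowwise <- function(y_true, y_pred) {
  y_true <- as.matrix(y_true)
  y_pred <- as.matrix(y_pred)
  stopifnot(identical(dim(y_true), dim(y_pred)))
  # Center each row on its own mean; rowMeans() is recycled down the columns,
  # so every element has its own row's mean subtracted.
  dt <- y_true - rowMeans(y_true)
  dp <- y_pred - rowMeans(y_pred)
  # Numerator: sum of products of deviations.
  # Denominator: product of the L2 norms of the deviations.
  r <- rowSums(dt * dp) / (sqrt(rowSums(dt^2)) * sqrt(rowSums(dp^2)))
  # Report the mean coefficient over samples (rows).
  mean(r)
}

y_true <- rbind(c(0, 1, 0.5),
                c(1, 1, 0.2))
y_pred <- rbind(c(0.1, 0.9, 0.5),
                c(1, 0.9, 0.2))
pcc_rowwise(y_true, y_pred)
```

On the data from the Examples section this should agree with the `stats::cor()` comparison shown there.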

The Pearson correlation coefficient (PCC) measures the linear relationship between the true values (y_true) and the predicted values (y_pred). The coefficient ranges from -1 to 1: a value of 1 indicates a perfect positive linear correlation, 0 indicates no linear correlation, and -1 indicates a perfect negative linear correlation.
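The three reference values can be illustrated with base R's `stats::cor()`, which computes the same coefficient for a single pair of vectors (keras3 is not needed here). The last case is a reminder that "no linear correlation" does not mean "no relationship": a perfect quadratic dependence that is symmetric around the mean gives r = 0.

```r
x <- c(1, 2, 3, 4, 5)

cor(x, 2 * x + 3)    # 1: perfect positive linear relationship
cor(x, -0.5 * x)     # -1: perfect negative linear relationship

# y depends exactly on x, but not linearly, so r is 0 (up to rounding).
y_quad <- (x - 3)^2
cor(x, y_quad)
```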

This metric is widely used in regression tasks where the strength of the linear relationship between predictions and true labels is an important evaluation criterion.
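Because the coefficient only measures how well the predictions follow a straight line against the targets, it is unaffected by any positive rescaling or shift of the predictions. The base R sketch below (hypothetical data, `stats::cor()` only) shows predictions with a perfect correlation but a large mean squared error, which is why PCC is usually reported alongside an error-based loss such as mean squared error rather than on its own.

```r
y_obs <- c(1, 2, 3, 4)
y_hat <- 10 * y_obs + 5          # right trend, wrong scale and offset

cor(y_obs, y_hat)                # 1: perfectly linearly related
mean((y_obs - y_hat)^2)          # 857.5: yet far from the true values
```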

Usage

metric_pearson_correlation(
  y_true,
  y_pred,
  axis = -1L,
  ...,
  name = "pearson_correlation",
  dtype = NULL
)

Arguments

y_true

Tensor of true targets.

y_pred

Tensor of predicted targets.

axis

(Optional) Integer, or vector of integers, giving the axis or axes along which to compute the metric. Defaults to -1 (the last axis).
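To make the role of `axis` concrete, the base R sketch below (hypothetical data, `stats::cor()` only) assumes a 2-D batch with one sample per row. With the default `axis = -1`, each row is correlated with its prediction and the per-row coefficients are averaged; reducing along the first axis instead would correlate each column across samples, giving a different number of coefficients. How keras3 numbers positive axis values is not covered by this sketch.

```r
# Hypothetical batch: 4 samples (rows) with 3 values each.
y_true <- rbind(c(1, 2, 3),
                c(2, 4, 7),
                c(3, 1, 5),
                c(5, 3, 4))
y_pred <- rbind(c(1, 3, 2),
                c(2, 5, 6),
                c(4, 1, 6),
                c(6, 2, 5))

# Along the last axis (the default): one coefficient per row (sample).
r_rows <- sapply(seq_len(nrow(y_true)),
                 function(i) cor(y_true[i, ], y_pred[i, ]))

# Along the first axis instead: one coefficient per column, across samples.
r_cols <- sapply(seq_len(ncol(y_true)),
                 function(j) cor(y_true[, j], y_pred[, j]))

length(r_rows)   # 4 coefficients
length(r_cols)   # 3 coefficients
mean(r_rows)     # corresponds to the default axis = -1 for this batch
```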

...

For forward/backward compatibility.

name

(Optional) string name of the metric instance.

dtype

(Optional) data type of the metric result.

Examples

pcc <- metric_pearson_correlation(axis = -1)
y_true <- rbind(c(0, 1, 0.5),
                c(1, 1, 0.2))
y_pred <- rbind(c(0.1, 0.9, 0.5),
                c(1, 0.9, 0.2))
pcc$update_state(y_true, y_pred)
pcc$result()

## tf.Tensor(0.99669963, shape=(), dtype=float32)

# equivalent operation using R's stats::cor()
mean(sapply(1:nrow(y_true), function(i) {
  cor(y_true[i, ], y_pred[i, ])
}))

## [1] 0.9966996
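The metric is stateful, so repeated `update_state()` calls accumulate. The sketch below is a hypothetical base R accumulator (`make_pcc_accumulator()` is not keras3 code) that assumes, as with Keras's other mean-based metrics, that `result()` reports the unweighted running mean of per-sample coefficients over every batch seen so far. Under that assumption, feeding the two example rows in separate batches gives the same result as feeding them together.

```r
# Hypothetical running-mean accumulator (not keras3 code).
make_pcc_accumulator <- function() {
  total <- 0   # sum of per-sample coefficients seen so far
  count <- 0   # number of samples seen so far
  list(
    update_state = function(y_true, y_pred) {
      r <- sapply(seq_len(nrow(y_true)),
                  function(i) cor(y_true[i, ], y_pred[i, ]))
      total <<- total + sum(r)
      count <<- count + length(r)
      invisible(NULL)
    },
    result = function() total / count,
    reset_state = function() {
      total <<- 0
      count <<- 0
      invisible(NULL)
    }
  )
}

y_true <- rbind(c(0, 1, 0.5),
                c(1, 1, 0.2))
y_pred <- rbind(c(0.1, 0.9, 0.5),
                c(1, 0.9, 0.2))

acc <- make_pcc_accumulator()
acc$update_state(y_true[1, , drop = FALSE], y_pred[1, , drop = FALSE])
acc$update_state(y_true[2, , drop = FALSE], y_pred[2, , drop = FALSE])
acc$result()   # same as a single update with both rows
```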

Usage with compile() API:

model |> compile(
  optimizer = "sgd",
  loss = "mean_squared_error",
  metrics = list(metric_pearson_correlation())
)

See also

Other regression metrics:
metric_concordance_correlation()
metric_cosine_similarity()
metric_log_cosh_error()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_r2_score()
metric_root_mean_squared_error()

Other metrics:
Metric()
custom_metric()
metric_auc()
metric_binary_accuracy()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_binary_iou()
metric_categorical_accuracy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_concordance_correlation()
metric_cosine_similarity()
metric_f1_score()
metric_false_negatives()
metric_false_positives()
metric_fbeta_score()
metric_hinge()
metric_huber()
metric_iou()
metric_kl_divergence()
metric_log_cosh()
metric_log_cosh_error()
metric_mean()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_iou()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_mean_wrapper()
metric_one_hot_iou()
metric_one_hot_mean_iou()
metric_poisson()
metric_precision()
metric_precision_at_recall()
metric_r2_score()
metric_recall()
metric_recall_at_precision()
metric_root_mean_squared_error()
metric_sensitivity_at_specificity()
metric_sparse_categorical_accuracy()
metric_sparse_categorical_crossentropy()
metric_sparse_top_k_categorical_accuracy()
metric_specificity_at_sensitivity()
metric_squared_hinge()
metric_sum()
metric_top_k_categorical_accuracy()
metric_true_negatives()
metric_true_positives()