def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulate one batch into the running confusion matrix.

    Args:
        y_true: Integer class labels of shape ``(batch, ...)``.
        y_pred: Either per-class scores of shape ``(batch, ..., num_classes)``,
            which are reduced with ``argmax`` over the last axis, or integer
            class labels laid out like ``y_true``.
        sample_weight: Optional weights. Either one weight per label (any
            shape with as many elements as ``y_true``; it is flattened in
            order), or a lower-rank tensor matching ``y_true``'s leading axes
            (e.g. per-sample ``(batch,)`` weights, or a scalar), which is
            broadcast across the remaining label axes.

    Raises:
        ValueError: If a label or prediction lies outside
            ``[0, num_classes - 1]`` and the values can be inspected eagerly.
            In a TensorFlow graph the same condition is enforced by a
            ``tf.debugging.assert_equal`` runtime assertion instead. Inside a
            traced JAX function, where Python cannot branch on data, the
            out-of-range entries are excluded from the matrix instead.
    """
    y_true = keras.ops.convert_to_tensor(y_true)
    y_pred = keras.ops.convert_to_tensor(y_pred)
    true_shape = tuple(y_true.shape)
    pred_shape = tuple(y_pred.shape)

    # y_pred holds scores when its last axis is the class axis. At the same
    # rank as y_true that only holds if y_true ends in a singleton (or unknown)
    # label axis; otherwise a label tensor whose last spatial axis merely
    # equals num_classes would be argmax-ed by mistake.
    pred_is_scores = len(pred_shape) > 1 and pred_shape[-1] == self.num_classes
    if pred_is_scores and len(pred_shape) == len(true_shape):
        pred_is_scores = true_shape[-1] in (1, None)
    pred_labels = keras.ops.argmax(y_pred, axis=-1) if pred_is_scores else y_pred

    y_true_flat = keras.ops.cast(keras.ops.reshape(y_true, (-1,)), "int32")
    pred_flat = keras.ops.cast(keras.ops.reshape(pred_labels, (-1,)), "int32")
    invalid = keras.ops.logical_or(
        keras.ops.logical_or(y_true_flat < 0, y_true_flat >= self.num_classes),
        keras.ops.logical_or(pred_flat < 0, pred_flat >= self.num_classes),
    )
    has_invalid = keras.ops.any(invalid)
    invalid_message = (
        f"labels and predictions must be in [0, {self.num_classes - 1}]"
    )

    backend_name = keras.backend.backend()
    jax_traced = False
    if backend_name == "jax":
        import jax

        jax_traced = isinstance(has_invalid, jax.core.Tracer)

    valid_mask = None  # set only when invalid entries must be masked, not raised
    if backend_name == "tensorflow":
        import tensorflow as tf

        if tf.executing_eagerly():
            if bool(tf.get_static_value(has_invalid)):
                raise ValueError(invalid_message)
        else:
            # Graph mode: the value is unknown at trace time, so gate the
            # downstream computation on a runtime assertion.
            assertion = tf.debugging.assert_equal(
                has_invalid,
                False,
                message=invalid_message,
            )
            with tf.control_dependencies([assertion]):
                y_true_flat = tf.identity(y_true_flat)
                pred_flat = tf.identity(pred_flat)
    elif jax_traced:
        # Under jax.jit the value is abstract: convert_to_numpy would raise a
        # tracer-conversion error and Python cannot branch on it. Exclude the
        # offending entries rather than letting them reach the scatter.
        valid_mask = keras.ops.logical_not(invalid)
    elif bool(keras.ops.convert_to_numpy(has_invalid)):
        raise ValueError(invalid_message)

    if sample_weight is not None:
        sample_weight = keras.ops.convert_to_tensor(
            sample_weight, dtype=self._state_dtype
        )
        weight_shape = tuple(sample_weight.shape)
        # A lower-rank weight matching y_true's leading axes (e.g. (batch,)
        # for (batch, H, W) labels, or a scalar) would otherwise flatten to
        # fewer elements than the labels; broadcast it to y_true's shape first.
        if len(weight_shape) < len(true_shape) and all(
            w is None or t is None or w == t
            for w, t in zip(weight_shape, true_shape)
        ):
            for _ in range(len(true_shape) - len(weight_shape)):
                sample_weight = keras.ops.expand_dims(sample_weight, axis=-1)
            sample_weight = keras.ops.broadcast_to(
                sample_weight, keras.ops.shape(y_true)
            )
        sample_weight = keras.ops.reshape(sample_weight, (-1,))

    if valid_mask is not None:
        # Point invalid entries at a safe index and give them zero weight so
        # they contribute nothing to the batch matrix.
        zeros = keras.ops.zeros_like(y_true_flat)
        y_true_flat = keras.ops.where(valid_mask, y_true_flat, zeros)
        pred_flat = keras.ops.where(valid_mask, pred_flat, zeros)
        mask_weight = keras.ops.cast(valid_mask, self._state_dtype)
        sample_weight = (
            mask_weight if sample_weight is None else sample_weight * mask_weight
        )

    batch_conf_matrix = metrics_utils.confusion_matrix(
        labels=y_true_flat,
        predictions=pred_flat,
        num_classes=self.num_classes,
        weights=sample_weight,
        dtype=self._state_dtype,
    )
    self.conf_matrix.assign_add(batch_conf_matrix)