focal_loss.sparse_categorical_focal_loss

focal_loss.sparse_categorical_focal_loss(y_true, y_pred, gamma, *, class_weight: Optional[Any] = None, from_logits: bool = False, axis: int = -1) → tensorflow.python.framework.ops.Tensor

Focal loss function for multiclass classification with integer labels.

This loss function generalizes multiclass softmax cross-entropy by introducing a hyperparameter called the focusing parameter that allows hard-to-classify examples to be penalized more heavily relative to easy-to-classify examples.

See binary_focal_loss() for a description of the focal loss in the binary setting, as presented in the original work [1].

In the multiclass setting, with integer labels \(y\), focal loss is defined as

\[L(y, \hat{\mathbf{p}}) = -\left(1 - \hat{p}_y\right)^\gamma \log(\hat{p}_y)\]

where

  • \(y \in \{0, \ldots, K - 1\}\) is an integer class label (\(K\) denotes the number of classes),
  • \(\hat{\mathbf{p}} = (\hat{p}_0, \ldots, \hat{p}_{K-1}) \in [0, 1]^K\) is a vector representing an estimated probability distribution over the \(K\) classes,
  • \(\gamma\) (gamma, not \(y\)) is the focusing parameter that specifies how much high-confidence correct predictions contribute to the overall loss (the higher the \(\gamma\), the higher the rate at which easy-to-classify examples are down-weighted).

The usual multiclass softmax cross-entropy loss is recovered by setting \(\gamma = 0\).
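The definition above can be sketched directly in plain NumPy (an illustrative sketch of the formula, not the library's TensorFlow implementation); in particular, setting \(\gamma = 0\) makes the modulating factor \((1 - \hat{p}_y)^\gamma\) identically 1, recovering ordinary cross-entropy:

```python
import numpy as np

def focal_loss_sketch(y_true, y_pred, gamma=2.0):
    """Per-example focal loss -(1 - p_y)^gamma * log(p_y). Illustrative only."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=float)
    # Probability assigned to the true class of each example
    p_y = y_pred[np.arange(len(y_true)), y_true]
    return -((1.0 - p_y) ** gamma) * np.log(p_y)

y_true = [0, 1, 2]
y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]

# gamma = 0 reduces to the usual cross-entropy -log(p_y)
ce = focal_loss_sketch(y_true, y_pred, gamma=0.0)
assert np.allclose(ce, -np.log([0.8, 0.7, 0.6]))
```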

Parameters:
  • y_true (tensor-like) – Integer class labels.
  • y_pred (tensor-like) – Either probabilities or logits, depending on the from_logits parameter.
  • gamma (float or tensor-like of shape (K,)) – The focusing parameter \(\gamma\). Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative. This can be a one-dimensional tensor, in which case it specifies a focusing parameter for each class.
  • class_weight (tensor-like of shape (K,)) – Weighting factor for each of the \(K\) classes. If not specified, then all classes are weighted equally.
  • from_logits (bool, optional) – Whether y_pred contains logits or probabilities.
  • axis (int, optional) – Channel axis in the y_pred tensor.
Returns:

The focal loss for each example.

Return type:

tf.Tensor

Examples

This function computes the per-example focal loss between a one-dimensional integer label vector and a two-dimensional prediction matrix:

>>> import numpy as np
>>> from focal_loss import sparse_categorical_focal_loss
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
>>> loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2)
>>> np.set_printoptions(precision=3)
>>> print(loss.numpy())
[0.009 0.032 0.082]
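The printed values can be cross-checked against the formula with plain NumPy: the probability assigned to the true class is 0.8, 0.7, and 0.6 for the three examples, respectively.

```python
import numpy as np

# Probability of the true class for each of the three examples above
p_y = np.array([0.8, 0.7, 0.6])
gamma = 2
loss = -((1 - p_y) ** gamma) * np.log(p_y)
print(np.round(loss, 3))  # [0.009 0.032 0.082]
```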

Warning

This function does not reduce its output to a scalar, so it cannot be passed to tf.keras.Model.compile() as a loss argument. Instead, use the wrapper class SparseCategoricalFocalLoss.

References

[1]T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018.

See also

SparseCategoricalFocalLoss
A wrapper around this function that makes it a tf.keras.losses.Loss.