focal_loss.sparse_categorical_focal_loss

focal_loss.sparse_categorical_focal_loss(y_true, y_pred, gamma, *, class_weight: Optional[Any] = None, from_logits: bool = False, axis: int = -1) → tf.Tensor

Focal loss function for multiclass classification with integer labels.
This loss function generalizes multiclass softmax cross-entropy by introducing a hyperparameter called the focusing parameter that allows hard-to-classify examples to be penalized more heavily relative to easy-to-classify examples.
See binary_focal_loss() for a description of the focal loss in the binary setting, as presented in the original work [1].

In the multiclass setting, with integer labels \(y\), focal loss is defined as

\[L(y, \hat{\mathbf{p}}) = -\left(1 - \hat{p}_y\right)^\gamma \log(\hat{p}_y)\]

where
- \(y \in \{0, \ldots, K - 1\}\) is an integer class label (\(K\) denotes the number of classes),
- \(\hat{\mathbf{p}} = (\hat{p}_0, \ldots, \hat{p}_{K-1}) \in [0, 1]^K\) is a vector representing an estimated probability distribution over the \(K\) classes,
- \(\gamma\) (gamma, not \(y\)) is the focusing parameter that controls how strongly high-confidence correct predictions are down-weighted: the higher the \(\gamma\), the higher the rate at which easy-to-classify examples contribute less to the overall loss.
The usual multiclass softmax cross-entropy loss is recovered by setting \(\gamma = 0\).
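The formula above can be written out directly. The following NumPy sketch is an illustration of the definition, not the library's implementation (which operates on TensorFlow tensors and supports logits):

```python
import numpy as np

def focal_loss_np(y_true, y_pred, gamma):
    """Per-example multiclass focal loss, -(1 - p_y)^gamma * log(p_y).

    y_true: integer labels, shape (N,); y_pred: probabilities, shape (N, K).
    A plain NumPy sketch of the formula, not the library's implementation.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=float)
    # Probability assigned to the true class of each example.
    p_y = y_pred[np.arange(len(y_true)), y_true]
    return -((1.0 - p_y) ** gamma) * np.log(p_y)
```

Setting gamma=0 makes the modulating factor \((1 - \hat{p}_y)^0 = 1\), recovering ordinary cross-entropy \(-\log(\hat{p}_y)\).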
Parameters: - y_true (tensor-like) – Integer class labels.
- y_pred (tensor-like) – Either probabilities or logits, depending on the from_logits parameter.
- gamma (float or tensor-like of shape (K,)) – The focusing parameter \(\gamma\). Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative. This can be a one-dimensional tensor, in which case it specifies a focusing parameter for each class.
- class_weight (tensor-like of shape (K,)) – Weighting factor for each of the \(K\) classes. If not specified, then all classes are weighted equally.
- from_logits (bool, optional) – Whether y_pred contains logits or probabilities.
- axis (int, optional) – Channel axis in the y_pred tensor.
Returns: The focal loss for each example.
Return type: tf.Tensor
Examples
This function computes the per-example focal loss between a one-dimensional integer label vector and a two-dimensional prediction matrix:
>>> import numpy as np
>>> from focal_loss import sparse_categorical_focal_loss
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
>>> loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2)
>>> np.set_printoptions(precision=3)
>>> print(loss.numpy())
[0.009 0.032 0.082]
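The per-class gamma and class_weight options described in the parameters above can also be sketched in NumPy. This is an illustration of the documented behavior under stated assumptions (each example's loss is scaled by the weight of its true class), not the library's code:

```python
import numpy as np

def focal_loss_sketch(y_true, y_pred, gamma, class_weight=None):
    """NumPy sketch: focal loss with per-class gamma and class weights.

    gamma may be a scalar or a length-K vector; class_weight, if given,
    is a length-K vector. Assumption for illustration: each example's
    loss is multiplied by the weight of its true class.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=float)
    # Broadcast a scalar gamma to one focusing parameter per class.
    gamma = np.broadcast_to(np.asarray(gamma, dtype=float), y_pred.shape[-1:])
    p_y = y_pred[np.arange(len(y_true)), y_true]
    loss = -((1.0 - p_y) ** gamma[y_true]) * np.log(p_y)
    if class_weight is not None:
        loss = loss * np.asarray(class_weight, dtype=float)[y_true]
    return loss
```

With a scalar gamma and no class_weight this reduces to the plain focal loss shown in the doctest above.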
Warning
This function does not reduce its output to a scalar, so it cannot be passed to tf.keras.Model.compile() as a loss argument. Instead, use the wrapper class SparseCategoricalFocalLoss.

References
[1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018. (DOI) (arXiv preprint)

See also

SparseCategoricalFocalLoss() – A wrapper around this function that makes it a tf.keras.losses.Loss.