focal_loss.SparseCategoricalFocalLoss

class focal_loss.SparseCategoricalFocalLoss(gamma, class_weight: Optional[Any] = None, from_logits: bool = False, **kwargs)

Bases: tensorflow.python.keras.losses.Loss
Focal loss function for multiclass classification with integer labels.
This loss function generalizes multiclass softmax cross-entropy by introducing a hyperparameter \(\gamma\) (gamma), called the focusing parameter, that allows hard-to-classify examples to be penalized more heavily relative to easy-to-classify examples.
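Written out for a single example with true label \(y\) and predicted class probabilities \(\hat{p}\), the loss can be sketched as follows. This assumes the standard focal-loss form; \(\alpha_y\) stands for the class weight of the true class (1 when class_weight is omitted), \(\gamma\) becomes \(\gamma_y\) when a per-class gamma is given, and \(\hat{p}\) is the softmax of the predictions when from_logits is True. Setting \(\gamma = 0\) recovers ordinary cross-entropy.

```latex
\ell(y, \hat{p}) = -\alpha_y \, (1 - \hat{p}_y)^{\gamma} \log \hat{p}_y,
\qquad
\gamma = 0,\ \alpha_y = 1 \;\Longrightarrow\; \ell(y, \hat{p}) = -\log \hat{p}_y
```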
This class is a wrapper around sparse_categorical_focal_loss. See the documentation there for details about this loss function.

Parameters:
- gamma (float or tensor-like of shape (K,)) – The focusing parameter \(\gamma\). Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative. This can be a one-dimensional tensor, in which case it specifies a focusing parameter for each class.
- class_weight (tensor-like of shape (K,), optional) – Weighting factor for each of the \(K\) classes. If not specified, then all classes are weighted equally.
- from_logits (bool, optional) – Whether the model's predictions are logits (True) or probabilities (False). Defaults to False.
- **kwargs (keyword arguments) – Other keyword arguments for tf.keras.losses.Loss (e.g., name or reduction).
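To make the interaction of the parameters concrete, here is a minimal pure-Python sketch of the per-example computation. It assumes the loss has the form \(-\alpha_y (1 - \hat{p}_y)^{\gamma_y} \log \hat{p}_y\), with a per-class gamma and class_weight both indexed by the true label \(y\). The function name focal_loss_per_example is hypothetical and not part of the focal_loss package; unlike a production implementation, it does no probability clipping, so a predicted probability of exactly 0 for the true class raises an error.

```python
import math
from typing import List, Optional, Sequence, Union


def focal_loss_per_example(
    y_true: Sequence[int],
    y_pred: Sequence[Sequence[float]],
    gamma: Union[float, Sequence[float]],
    class_weight: Optional[Sequence[float]] = None,
    from_logits: bool = False,
) -> List[float]:
    """Hypothetical pure-Python sketch of the per-example sparse focal loss.

    Assumes loss_i = -alpha[y_i] * (1 - p_i[y_i]) ** gamma[y_i] * log(p_i[y_i]).
    """
    losses = []
    for label, row in zip(y_true, y_pred):
        if from_logits:
            # Numerically stable softmax: shift by the row maximum first.
            m = max(row)
            exps = [math.exp(z - m) for z in row]
            total = sum(exps)
            probs = [e / total for e in exps]
        else:
            probs = list(row)
        p = probs[label]
        # A per-class gamma is looked up by the true class label.
        g = gamma[label] if isinstance(gamma, (list, tuple)) else gamma
        loss = -((1.0 - p) ** g) * math.log(p)
        # An optional class weight scales the loss of the true class.
        if class_weight is not None:
            loss *= class_weight[label]
        losses.append(loss)
    return losses


y_true = [0, 1, 2]
y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
# gamma = 0 without class weights reduces to plain cross-entropy, -log p_y.
print(focal_loss_per_example(y_true, y_pred, gamma=0))
# Per-class gamma and class weights, both indexed by the true label.
print(focal_loss_per_example(y_true, y_pred, gamma=[2.0, 1.0, 0.5], class_weight=[1.0, 2.0, 0.5]))
```

Raising gamma shrinks every per-example loss by the factor \((1 - \hat{p}_y)^{\gamma}\), which is smallest for confidently correct predictions; that is the "focusing" effect described above.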
Examples
An instance of this class is a callable that takes a rank-one tensor of integer class labels y_true and a tensor of model predictions y_pred and returns a scalar tensor obtained by reducing the per-example focal loss (the default reduction is a batch-wise average).
```python
>>> from focal_loss import SparseCategoricalFocalLoss
>>> loss_func = SparseCategoricalFocalLoss(gamma=2)
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
>>> loss_func(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.040919524>
```
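The value in this example can be checked by hand. The sketch below assumes the per-example loss is \(-(1 - \hat{p}_y)^{\gamma} \log \hat{p}_y\) with no class weights and averages the three per-example losses in plain Python, without TensorFlow; the result agrees with the float32 value above to within rounding.

```python
import math

# Inputs copied from the example above.
y_true = [0, 1, 2]
y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
gamma = 2.0

# Per-example focal loss: -(1 - p_y) ** gamma * log(p_y).
per_example = [
    -((1.0 - row[label]) ** gamma) * math.log(row[label])
    for label, row in zip(y_true, y_pred)
]

# Batch-wise average, the default reduction.
batch_loss = sum(per_example) / len(per_example)
print(f"{batch_loss:.9f}")
```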
Use this class in the tf.keras API like any other multiclass classification loss function class found in tf.keras.losses that accepts integer labels (e.g., tf.keras.losses.SparseCategoricalCrossentropy):

```python
# Typical usage; the ... placeholders stand for your own model, optimizer, metrics, and data.
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=SparseCategoricalFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)
```
See also

sparse_categorical_focal_loss() – The function that performs the focal loss computation, taking a label tensor and a prediction tensor and outputting a loss.
call(y_true, y_pred)

Compute the per-example focal loss.
This method simply calls sparse_categorical_focal_loss() with the appropriate arguments.

Parameters:
- y_true (tensor-like, shape (N,)) – Integer class labels.
- y_pred (tensor-like, shape (N, K)) – Either probabilities or logits, depending on the from_logits parameter.

Returns: The per-example focal loss. Reduction to a scalar is handled by the loss's __call__() method, inherited from tf.keras.losses.Loss.
Return type: tf.Tensor
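The split between call() and __call__() can be illustrated without TensorFlow. In this sketch, the hypothetical helper reduce_losses mimics the tf.keras reduction modes "none", "sum", and "sum_over_batch_size" (selectable through the reduction keyword argument) applied to the vector of per-example losses that call() returns. It ignores sample weights, and the per-example values are made up for illustration.

```python
from typing import List, Union


def reduce_losses(
    per_example: List[float], reduction: str = "sum_over_batch_size"
) -> Union[float, List[float]]:
    """Hypothetical sketch of a tf.keras-style reduction over per-example losses."""
    if reduction == "none":
        return list(per_example)  # shape (N,): one loss per example
    if reduction == "sum":
        return sum(per_example)  # total over the batch
    if reduction == "sum_over_batch_size":
        return sum(per_example) / len(per_example)  # batch-wise average
    raise ValueError(f"unknown reduction: {reduction!r}")


per_example = [0.25, 0.5, 1.5]  # hypothetical per-example focal losses
print(reduce_losses(per_example, "none"))  # [0.25, 0.5, 1.5]
print(reduce_losses(per_example, "sum"))  # 2.25
print(reduce_losses(per_example))  # 0.75
```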
classmethod from_config(config)

Instantiates a Loss from its config (output of get_config()).

Parameters: config – Output of get_config().
Returns: A Loss instance.