focal_loss.SparseCategoricalFocalLoss

class focal_loss.SparseCategoricalFocalLoss(gamma, class_weight: Optional[Any] = None, from_logits: bool = False, **kwargs)

Bases: tensorflow.python.keras.losses.Loss

Focal loss function for multiclass classification with integer labels.

This loss function generalizes multiclass softmax cross-entropy by introducing a hyperparameter \(\gamma\) (gamma), called the focusing parameter, that allows hard-to-classify examples to be penalized more heavily relative to easy-to-classify examples.
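To make the role of \(\gamma\) concrete, here is a minimal pure-Python sketch of the focal loss for a single example. It assumes the standard formulation \(-(1 - p_y)^\gamma \log p_y\), where \(p_y\) is the predicted probability of the true class, so that \(\gamma = 0\) gives ordinary cross-entropy. The helper name `focal_loss_one` is hypothetical and is not part of the focal_loss package.

```python
import math


def focal_loss_one(probs, label, gamma):
    """Focal loss for one example (hypothetical helper, not the library API).

    probs: predicted class probabilities (should sum to 1)
    label: integer index of the true class
    gamma: non-negative focusing parameter
    """
    p_true = probs[label]
    # (1 - p_true)^gamma is the modulating factor; gamma = 0 leaves plain cross-entropy.
    return -((1.0 - p_true) ** gamma) * math.log(p_true)


easy = [0.9, 0.05, 0.05]  # true class 0 predicted confidently
hard = [0.3, 0.4, 0.3]    # true class 0 gets a low probability

for gamma in (0.0, 2.0):
    ratio = focal_loss_one(hard, 0, gamma) / focal_loss_one(easy, 0, gamma)
    print(f"gamma={gamma}: hard/easy loss ratio = {ratio:.1f}")
```

With these inputs the hard-to-easy loss ratio grows from about 11 at gamma=0 to about 560 at gamma=2, which is the relative up-weighting of hard examples described above.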

This class is a wrapper around sparse_categorical_focal_loss. See the documentation there for details about this loss function.

Parameters:
  • gamma (float or tensor-like of shape (K,)) – The focusing parameter \(\gamma\). Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative. This can be a one-dimensional tensor, in which case it specifies a focusing parameter for each class.
  • class_weight (tensor-like of shape (K,), optional) – Weighting factor for each of the \(K\) classes. If not specified, then all classes are weighted equally.
  • from_logits (bool, optional) – Whether y_pred contains logits (True) or probabilities (False, the default).
  • **kwargs (keyword arguments) – Other keyword arguments for tf.keras.losses.Loss (e.g., name or reduction).
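The effect of each parameter can be read off a small reference computation. The sketch below is plain Python (no TensorFlow) and is not the library's implementation. It assumes that a per-class gamma and the class weight are both looked up by the true label, so the per-example loss is \(-w_y (1 - p_y)^{\gamma_y} \log p_y\), and that from_logits=True means a softmax is applied to each row first. The names `softmax` and `focal_loss_reference` are hypothetical.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss_reference(y_true, y_pred, gamma, class_weight=None, from_logits=False):
    """Per-example sparse focal loss (hypothetical pure-Python reference).

    y_true: list of N integer labels in [0, K)
    y_pred: N rows of K probabilities, or K logits if from_logits=True
    gamma: a number, or a length-K sequence giving one gamma per class
    class_weight: optional length-K sequence of per-class weights
    """
    losses = []
    for label, row in zip(y_true, y_pred):
        probs = softmax(row) if from_logits else row
        p_true = probs[label]
        # A per-class gamma is selected by the true label.
        g = gamma[label] if isinstance(gamma, (list, tuple)) else gamma
        loss = -((1.0 - p_true) ** g) * math.log(p_true)
        if class_weight is not None:
            # The class weight is also selected by the true label.
            loss *= class_weight[label]
        losses.append(loss)
    return losses


y_true = [0, 1, 2]
probs = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
# Log-probabilities are valid logits: softmax maps them back to probs.
logits = [[math.log(p) for p in row] for row in probs]

print(focal_loss_reference(y_true, probs, gamma=2))
print(focal_loss_reference(y_true, logits, gamma=[2, 2, 2], from_logits=True))
print(focal_loss_reference(y_true, probs, gamma=2, class_weight=[1.0, 1.0, 4.0]))
```

The first two calls agree, since a scalar gamma behaves like a per-class gamma with equal entries; the third scales only the loss of the example whose true class carries weight 4.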

Examples

An instance of this class is a callable that takes a rank-one tensor of integer class labels y_true and a tensor of model predictions y_pred and returns a scalar tensor obtained by reducing the per-example focal loss (the default reduction is a batch-wise average).

>>> from focal_loss import SparseCategoricalFocalLoss
>>> loss_func = SparseCategoricalFocalLoss(gamma=2)
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
>>> loss_func(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.040919524>
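As a sanity check, the value above can be recomputed by hand in plain Python, assuming the per-example loss \(-(1 - p_y)^\gamma \log p_y\) and the batch-wise average mentioned above. The last digits differ slightly because TensorFlow computes in float32 while Python uses float64.

```python
import math

y_true = [0, 1, 2]
y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
gamma = 2

# Per-example focal loss: -(1 - p_true)^gamma * log(p_true)
per_example = [
    -((1.0 - row[label]) ** gamma) * math.log(row[label])
    for label, row in zip(y_true, y_pred)
]

# Default reduction: average over the batch.
mean_loss = sum(per_example) / len(per_example)
print(round(mean_loss, 6))  # 0.04092
```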

Use this class in the tf.keras API like any other multiclass classification loss class in tf.keras.losses that accepts integer labels (e.g., tf.keras.losses.SparseCategoricalCrossentropy):

# Typical usage
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=SparseCategoricalFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)

See also

sparse_categorical_focal_loss()
The function that performs the focal loss computation, taking a label tensor and a prediction tensor and outputting a loss.
call(y_true, y_pred)

Compute the per-example focal loss.

This method simply calls sparse_categorical_focal_loss() with the appropriate arguments.

Parameters:
  • y_true (tensor-like, shape (N,)) – Integer class labels.
  • y_pred (tensor-like, shape (N, K)) – Either probabilities or logits, depending on the from_logits parameter.
Returns:

The per-example focal loss. Reduction to a scalar is handled by this loss object’s __call__() method.

Return type:

tf.Tensor
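The split between call() (per-example values) and __call__() (reduction) can be illustrated with a small pure-Python sketch of the reductions that tf.keras.losses.Loss accepts through its reduction keyword argument (passed via **kwargs): "sum_over_batch_size" (the batch-wise average), "sum", and "none". The helper `reduce_losses` is hypothetical; it mirrors the idea only and ignores sample weights and distribution strategies.

```python
def reduce_losses(per_example, reduction="sum_over_batch_size"):
    """Reduce per-example losses roughly as a Keras Loss's __call__ would
    (hypothetical helper; sample weights are not handled)."""
    if reduction == "none":
        return list(per_example)  # one value per example, like call()
    total = sum(per_example)
    if reduction == "sum":
        return total
    if reduction == "sum_over_batch_size":
        return total / len(per_example)
    raise ValueError(f"unknown reduction: {reduction!r}")


per_example = [0.25, 0.5, 1.5]  # stand-in for the output of call() with N = 3
print(reduce_losses(per_example))          # 0.75
print(reduce_losses(per_example, "sum"))   # 2.25
print(reduce_losses(per_example, "none"))  # [0.25, 0.5, 1.5]
```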

classmethod from_config(config)

Instantiates a Loss from its config (output of get_config()).

Parameters: config – Output of get_config().
Returns: A Loss instance.
get_config()

Returns the config of the loss.

The config is a Python dictionary containing the configuration of the loss. The same loss can be re-instantiated later from this configuration with from_config().

Returns: This loss’s config.
Return type: dict
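The get_config()/from_config() pair is meant to round-trip: the base Keras Loss.from_config() re-creates the object by passing the dictionary back to the constructor as keyword arguments. The sketch below mimics that pattern with a hypothetical stand-in class, `ToyFocalLoss`, so it runs without TensorFlow. It assumes the config holds gamma, class_weight, and from_logits alongside base-class entries such as name and reduction; the default values shown are illustrative, not the library's.

```python
class ToyFocalLoss:
    """Hypothetical stand-in for a Keras Loss subclass (no TensorFlow needed)."""

    def __init__(self, gamma, class_weight=None, from_logits=False,
                 name="toy_focal_loss", reduction="auto"):
        self.gamma = gamma
        self.class_weight = class_weight
        self.from_logits = from_logits
        self.name = name
        self.reduction = reduction

    def get_config(self):
        # Base-class entries first, then the focal-loss-specific arguments.
        config = {"name": self.name, "reduction": self.reduction}
        config.update(gamma=self.gamma, class_weight=self.class_weight,
                      from_logits=self.from_logits)
        return config

    @classmethod
    def from_config(cls, config):
        # Same pattern as Keras: feed the dictionary back into the constructor.
        return cls(**config)


original = ToyFocalLoss(gamma=2, class_weight=[1.0, 2.0, 0.5], from_logits=True)
clone = ToyFocalLoss.from_config(original.get_config())
print(clone.get_config() == original.get_config())  # True
```

If a subclass forgot to add one of its constructor arguments to get_config(), from_config() would silently fall back to that argument's default, which is why every argument is included.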