focal_loss.BinaryFocalLoss

class focal_loss.BinaryFocalLoss(gamma, *, pos_weight=None, from_logits=False, label_smoothing=None, **kwargs)

Bases: tensorflow.python.keras.losses.Loss

Focal loss function for binary classification.

This loss function generalizes binary cross-entropy by introducing a hyperparameter called the focusing parameter that allows hard-to-classify examples to be penalized more heavily relative to easy-to-classify examples.
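For reference, the per-example loss can be written as below. This is a sketch that assumes, as the pos_weight description suggests, that \(\alpha\) multiplies only the positive-class term and is taken as 1 when pos_weight is None; \(y \in \{0, 1\}\) is the label and \(p\) is the predicted probability of the positive class.

\[
\mathcal{L}(y, p) = -\alpha\, y\, (1 - p)^{\gamma} \log p \;-\; (1 - y)\, p^{\gamma} \log(1 - p)
\]

Setting \(\gamma = 0\) and \(\alpha = 1\) recovers binary cross-entropy, \(-y \log p - (1 - y) \log(1 - p)\). For \(\gamma > 0\), the factors \((1 - p)^{\gamma}\) and \(p^{\gamma}\) shrink the loss on confidently correct predictions much more than on poor ones.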

This class is a wrapper around binary_focal_loss. See the documentation there for details about this loss function.
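A minimal plain-Python sketch (no TensorFlow) of the claim that focal loss generalizes binary cross-entropy. The helpers focal_term and bce are hypothetical, and the formula assumes that pos_weight (alpha) scales only the positive-class term; this is an illustration, not the library's implementation.

```python
import math


def focal_term(y, p, gamma, alpha=1.0):
    """Per-example binary focal loss for a label y in {0, 1} and probability p.

    Assumption: alpha (pos_weight) scales only the positive-class term.
    """
    return -(alpha * y * (1 - p) ** gamma * math.log(p)
             + (1 - y) * p ** gamma * math.log(1 - p))


def bce(y, p):
    """Plain binary cross-entropy, for comparison."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# With gamma = 0 (and alpha = 1) the focal loss equals binary cross-entropy.
for y, p in [(0, 0.1), (1, 0.7), (1, 0.9)]:
    assert math.isclose(focal_term(y, p, gamma=0), bce(y, p))

# With gamma = 2, an easy positive (p = 0.9) keeps (1 - 0.9)**2 = 1% of its
# cross-entropy, while a hard positive (p = 0.3) keeps (1 - 0.3)**2 = 49%.
easy = focal_term(1, 0.9, gamma=2) / bce(1, 0.9)
hard = focal_term(1, 0.3, gamma=2) / bce(1, 0.3)
print(f"easy: {easy:.2f}, hard: {hard:.2f}")  # easy: 0.01, hard: 0.49
```

This is the sense in which hard examples are penalized more heavily relative to easy ones: both losses shrink, but the easy example's loss shrinks far more.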

Parameters:
  • gamma (float) – The focusing parameter \(\gamma\). Must be non-negative.
  • pos_weight (float, optional) – The coefficient \(\alpha\) to use on the positive examples. Must be non-negative.
  • from_logits (bool, optional) – Whether the model's predictions are logits (True) or probabilities (False).
  • label_smoothing (float, optional) – Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary ground truth labels are squeezed toward 0.5, with larger values of label_smoothing leading to label values closer to 0.5.
  • **kwargs (keyword arguments) – Other keyword arguments for tf.keras.losses.Loss (e.g., name or reduction).
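The following plain-Python sketch illustrates what from_logits and label_smoothing do to the inputs before the loss is computed. prepare_inputs is a hypothetical helper, and both the sigmoid mapping and the smoothing formula are assumptions; the smoothing formula is one that matches the description above, moving 0 and 1 symmetrically toward 0.5.

```python
import math


def prepare_inputs(y_true, y_pred, from_logits=False, label_smoothing=None):
    """Hypothetical helper mirroring the meaning of from_logits and label_smoothing.

    Assumptions (not taken from the library source): logits are mapped to
    probabilities with the sigmoid function, and label smoothing uses
    y * (1 - label_smoothing) + 0.5 * label_smoothing.
    """
    if from_logits:
        # Sigmoid turns a real-valued logit into a positive-class probability.
        y_pred = [1.0 / (1.0 + math.exp(-z)) for z in y_pred]
    if label_smoothing:
        # Squeeze labels toward 0.5: 0 -> ls / 2 and 1 -> 1 - ls / 2.
        y_true = [y * (1 - label_smoothing) + 0.5 * label_smoothing for y in y_true]
    return y_true, y_pred


labels, probs = prepare_inputs([0, 1], [0.0, 2.0], from_logits=True, label_smoothing=0.2)
print([round(v, 3) for v in labels])  # [0.1, 0.9]
print([round(v, 3) for v in probs])   # [0.5, 0.881]
```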

Examples

An instance of this class is a callable that takes a tensor of binary ground truth labels y_true and a tensor of model predictions y_pred and returns a scalar tensor obtained by reducing the per-example focal loss (the default reduction is a batch-wise average).

>>> from focal_loss import BinaryFocalLoss
>>> loss_func = BinaryFocalLoss(gamma=2)
>>> loss = loss_func([0, 1, 1], [0.1, 0.7, 0.9])  # A scalar tensor
>>> print(f'Mean focal loss: {loss.numpy():.3f}')
Mean focal loss: 0.011
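As a sanity check, the doctest value can be reproduced by hand. This sketch assumes the per-example loss is \(-(1 - p_t)^{\gamma} \log p_t\), where \(p_t\) is the predicted probability of the true class (pos_weight unset), and that the default reduction is a plain mean over the batch.

```python
import math

# Recompute the doctest above with plain Python (no TensorFlow).
gamma = 2
y_true = [0, 1, 1]
y_pred = [0.1, 0.7, 0.9]

per_example = []
for y, p in zip(y_true, y_pred):
    p_t = p if y == 1 else 1 - p  # probability assigned to the correct class
    per_example.append(-((1 - p_t) ** gamma) * math.log(p_t))

mean_loss = sum(per_example) / len(per_example)  # batch-wise average
print([f"{v:.4f}" for v in per_example])  # ['0.0011', '0.0321', '0.0011']
print(f"Mean focal loss: {mean_loss:.3f}")  # Mean focal loss: 0.011
```

The middle example (a positive predicted at only 0.7) contributes almost all of the loss, while the two confident, correct predictions contribute very little.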

Use this class in the tf.keras API like any other binary classification loss function class found in tf.keras.losses (e.g., tf.keras.losses.BinaryCrossentropy):

# Typical usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)

See also

binary_focal_loss()
The function that performs the focal loss computation, taking a label tensor and a prediction tensor and outputting a loss.

call(y_true, y_pred)

Compute the per-example focal loss.

This method simply calls binary_focal_loss() with the appropriate arguments.

Parameters:
  • y_true (tensor-like) – Binary (0 or 1) class labels.
  • y_pred (tensor-like) – Either probabilities for the positive class or logits for the positive class, depending on the from_logits attribute. The shapes of y_true and y_pred should be broadcastable.
Returns:

The per-example focal loss. Reduction to a scalar is handled by this loss's __call__() method.

Return type:

tf.Tensor

classmethod from_config(config)

Instantiates a Loss from its config (output of get_config()).

Parameters: config – Output of get_config().
Returns: A Loss instance.

get_config()

Returns the config of this loss.

A loss config is a Python dictionary containing the configuration of the loss. The same loss can be re-instantiated later from this configuration (see from_config()).

Returns: This loss's config.
Return type: dict
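A minimal sketch of the serialization contract described by get_config() and from_config(), using a hypothetical ToyFocalLoss class rather than the real BinaryFocalLoss. It assumes the config is a dictionary of constructor arguments, so that from_config() can rebuild an equivalent instance with cls(**config); real Keras losses also store inherited entries such as name and reduction.

```python
class ToyFocalLoss:
    """Hypothetical stand-in (not the real BinaryFocalLoss) showing the
    get_config / from_config round trip."""

    def __init__(self, gamma, *, pos_weight=None, from_logits=False,
                 label_smoothing=None, name=None):
        self.gamma = gamma
        self.pos_weight = pos_weight
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.name = name

    def get_config(self):
        # A plain dict of constructor arguments; a loss has no trained weights.
        return {
            "gamma": self.gamma,
            "pos_weight": self.pos_weight,
            "from_logits": self.from_logits,
            "label_smoothing": self.label_smoothing,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent instance by passing the dict back as keyword arguments.
        return cls(**config)


original = ToyFocalLoss(gamma=2, pos_weight=0.25, from_logits=True, name="focal")
clone = ToyFocalLoss.from_config(original.get_config())
assert clone.get_config() == original.get_config()
```

Because the config holds only hyperparameters, the round trip yields an object that computes the same loss; this is the mechanism Keras relies on when reloading a saved model compiled with a custom loss class (typically supplied through custom_objects).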