focal_loss.BinaryFocalLoss
class focal_loss.BinaryFocalLoss(gamma, *, pos_weight=None, from_logits=False, label_smoothing=None, **kwargs)
Bases: tensorflow.python.keras.losses.Loss
Focal loss function for binary classification.
This loss function generalizes binary cross-entropy by introducing a hyperparameter called the focusing parameter that allows hard-to-classify examples to be penalized more heavily relative to easy-to-classify examples.
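For reference, the per-example loss can be written out as follows. This assumes the definition given in the binary_focal_loss documentation, which remains authoritative. Here \(y \in \{0, 1\}\) is the label, \(\hat{p}\) is the predicted probability of the positive class, \(\gamma\) is the focusing parameter, and \(\alpha\) is pos_weight (taken as 1 when pos_weight is None):

\[
L(y, \hat{p}) = -\alpha\, y\, (1 - \hat{p})^{\gamma} \log \hat{p} \;-\; (1 - y)\, \hat{p}^{\gamma} \log(1 - \hat{p})
\]

Setting \(\gamma = 0\) and \(\alpha = 1\) recovers binary cross-entropy. For \(\gamma > 0\), the modulating factors shrink the loss on confidently correct predictions. For example, with \(\gamma = 2\), a positive example predicted at \(\hat{p} = 0.9\) has its cross-entropy scaled by \((1 - 0.9)^2 = 0.01\), while a positive predicted at \(\hat{p} = 0.7\) is scaled by only \(0.09\).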
This class is a wrapper around binary_focal_loss. See the documentation there for details about this loss function.
Parameters:
- gamma (float) – The focusing parameter \(\gamma\). Must be non-negative.
- pos_weight (float, optional) – The coefficient \(\alpha\) to use on the positive examples. Must be non-negative.
- from_logits (bool, optional) – Whether model predictions are logits (True) or probabilities (False, the default).
- label_smoothing (float, optional) – Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary ground truth labels are squeezed toward 0.5, with larger values of label_smoothing leading to label values closer to 0.5.
- **kwargs (keyword arguments) – Other keyword arguments for tf.keras.losses.Loss (e.g., name or reduction).
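The parameters above can be illustrated with a plain-Python sketch of the per-example computation. The helper name binary_focal_loss_sketch is hypothetical and is not part of the focal_loss package. The sketch makes three assumptions. First, it uses the loss form from the binary_focal_loss documentation. Second, it converts logits with the logistic sigmoid when from_logits=True. Third, it follows the label-smoothing convention of tf.keras.losses.BinaryCrossentropy, y * (1 - s) + 0.5 * s. It is not numerically hardened: a probability of exactly 0 or 1 raises a math domain error.

```python
import math


def binary_focal_loss_sketch(y_true, y_pred, gamma, pos_weight=None,
                             from_logits=False, label_smoothing=None):
    """Per-example binary focal loss (hypothetical helper, not the library API)."""
    if gamma < 0:
        raise ValueError("gamma must be non-negative")
    if pos_weight is not None and pos_weight < 0:
        raise ValueError("pos_weight must be non-negative")
    # alpha weights only the positive-class term; 1 when pos_weight is None.
    alpha = 1.0 if pos_weight is None else pos_weight
    losses = []
    for y, pred in zip(y_true, y_pred):
        # Convert a logit to a probability with the sigmoid when requested.
        p = 1.0 / (1.0 + math.exp(-pred)) if from_logits else pred
        # Assumed smoothing convention: squeeze the label toward 0.5.
        if label_smoothing:
            y = y * (1.0 - label_smoothing) + 0.5 * label_smoothing
        # Positive-class term, down-weighted when p is already close to 1.
        pos_term = -alpha * y * (1.0 - p) ** gamma * math.log(p)
        # Negative-class term, down-weighted when p is already close to 0.
        neg_term = -(1.0 - y) * p ** gamma * math.log(1.0 - p)
        losses.append(pos_term + neg_term)
    return losses
```

Setting gamma=0 reduces each term to ordinary binary cross-entropy, which is a convenient sanity check for any implementation.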
Examples
An instance of this class is a callable that takes a tensor of binary ground truth labels y_true and a tensor of model predictions y_pred and returns a scalar tensor obtained by reducing the per-example focal loss (the default reduction is a batch-wise average).
>>> from focal_loss import BinaryFocalLoss
>>> loss_func = BinaryFocalLoss(gamma=2)
>>> loss = loss_func([0, 1, 1], [0.1, 0.7, 0.9])  # A scalar tensor
>>> print(f'Mean focal loss: {loss.numpy():.3f}')
Mean focal loss: 0.011
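As a cross-check, the 0.011 printed in the example can be recomputed by hand in plain Python without importing focal_loss or TensorFlow. This assumes the per-example form from the binary_focal_loss documentation with no pos_weight, averaged over the batch as the default reduction does.

```python
import math

# Same inputs as the doctest: binary labels and predicted probabilities.
y_true = [0, 1, 1]
y_pred = [0.1, 0.7, 0.9]
gamma = 2

# Per-example focal loss with no class weighting (pos_weight=None).
per_example = [
    -(y * (1 - p) ** gamma * math.log(p)
      + (1 - y) * p ** gamma * math.log(1 - p))
    for y, p in zip(y_true, y_pred)
]

# Batch-wise average, matching the default reduction.
mean_loss = sum(per_example) / len(per_example)
print(f"Mean focal loss: {mean_loss:.3f}")  # Mean focal loss: 0.011
```

The middle example (a positive predicted at 0.7) contributes most of the loss. The two confident, correct predictions are scaled down by a factor of 0.01.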
Use this class in the tf.keras API like any other binary classification loss function class found in tf.keras.losses (e.g., tf.keras.losses.BinaryCrossentropy):
import tensorflow as tf
from focal_loss import BinaryFocalLoss

# Typical usage
model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)
See also
binary_focal_loss() – The function that performs the focal loss computation, taking a label tensor and a prediction tensor and outputting a loss.
call(y_true, y_pred)
Compute the per-example focal loss.
This method simply calls binary_focal_loss() with the appropriate arguments.
Parameters:
- y_true (tensor-like) – Binary (0 or 1) class labels.
- y_pred (tensor-like) – Either probabilities for the positive class or logits for the positive class, depending on the from_logits attribute. The shapes of y_true and y_pred should be broadcastable.
Returns: The per-example focal loss. Reduction to a scalar is handled by the loss object's __call__() method.
Return type: tf.Tensor
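The split between call(), which returns per-example values, and __call__(), which reduces them, can be sketched as follows. reduce_losses is a hypothetical helper that mimics, under simplifying assumptions, the tf.keras reduction modes 'none', 'sum', and 'sum_over_batch_size'. The last mode is what the default 'auto' resolves to in most contexts. Optional sample weights are applied before reducing, and 'sum_over_batch_size' divides by the number of elements rather than by the sum of the weights.

```python
def reduce_losses(per_example, reduction="sum_over_batch_size", sample_weight=None):
    """Hypothetical helper mimicking how a Keras Loss reduces call() output."""
    # Apply per-example weights first, as Keras does before reducing.
    if sample_weight is not None:
        per_example = [loss * w for loss, w in zip(per_example, sample_weight)]
    if reduction == "none":
        # Keep one loss value per example.
        return list(per_example)
    total = sum(per_example)
    if reduction == "sum":
        return total
    if reduction == "sum_over_batch_size":
        # Mean over elements (not over weights); 0.0 for an empty batch.
        return total / len(per_example) if per_example else 0.0
    raise ValueError(f"unknown reduction: {reduction!r}")


losses = [0.1, 0.2, 0.3]
print(reduce_losses(losses, "none"))
print(reduce_losses(losses, "sum"))
print(reduce_losses(losses, sample_weight=[1, 0, 1]))
```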
classmethod from_config(config)
Instantiates a Loss from its config (output of get_config()).
Parameters: config – Output of get_config().
Returns: A Loss instance.
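from_config() is the inverse of get_config(): the config dictionary of constructor arguments is passed back to the constructor. The class below, ToyFocalLoss, is a hypothetical stand-in (not the real BinaryFocalLoss) that sketches this round trip. The exact keys that BinaryFocalLoss.get_config() returns are not documented here, so the ones shown are an assumption based on its constructor signature.

```python
class ToyFocalLoss:
    """Hypothetical stand-in for a Keras Loss subclass; not the focal_loss API."""

    def __init__(self, gamma, *, pos_weight=None, from_logits=False,
                 label_smoothing=None, name=None):
        self.gamma = gamma
        self.pos_weight = pos_weight
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.name = name

    def get_config(self):
        # Keys assumed from the constructor signature: everything needed
        # to rebuild an equivalent object.
        return {
            "gamma": self.gamma,
            "pos_weight": self.pos_weight,
            "from_logits": self.from_logits,
            "label_smoothing": self.label_smoothing,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        # Same pattern as tf.keras.losses.Loss.from_config: pass the dict
        # back to the constructor as keyword arguments.
        return cls(**config)


original = ToyFocalLoss(2.0, pos_weight=0.25, from_logits=True)
restored = ToyFocalLoss.from_config(original.get_config())
print(restored.get_config() == original.get_config())
```

Because the config holds only plain values, it can be serialized (for example, when a model is saved) and later turned back into an equivalent loss object.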