Focal Loss

TensorFlow implementation of focal loss: a loss function that generalizes binary cross-entropy by down-weighting well-classified examples, so that training concentrates on hard-to-classify ones.
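For intuition, here is the standard binary form of focal loss. Let p be the predicted probability of the positive class and y in {0, 1} the label. Define p_t = p when y = 1 and p_t = 1 - p otherwise. The loss is FL(p_t) = -(1 - p_t)^gamma * log(p_t). With gamma = 0 this is exactly binary cross-entropy. Larger gamma shrinks the loss on examples the model already gets right. The plain-Python sketch below illustrates that formula so it runs without TensorFlow. The helpers `focal_loss_sketch` and `bce_sketch` are hypothetical names, not this package's implementation, and they omit options such as class weighting or logit inputs.

```python
import math


def focal_loss_sketch(y_true, p_pred, gamma=2.0, eps=1e-7):
    """Hypothetical helper: mean binary focal loss over 0/1 labels and
    predicted positive-class probabilities."""
    total = 0.0
    for y, p in zip(y_true, p_pred):
        # Clip probabilities to avoid log(0).
        p = min(max(p, eps), 1.0 - eps)
        # p_t is the probability the model assigns to the true class.
        p_t = p if y == 1 else 1.0 - p
        # The modulating factor (1 - p_t)**gamma shrinks the loss of easy examples.
        total += -((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(y_true)


def bce_sketch(y_true, p_pred, eps=1e-7):
    """Hypothetical helper: mean binary cross-entropy, for comparison."""
    total = 0.0
    for y, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)


labels = [1, 0, 1]
probs = [0.9, 0.2, 0.3]

# With gamma = 0 the modulating factor is 1, recovering binary cross-entropy.
same_as_bce = math.isclose(focal_loss_sketch(labels, probs, gamma=0.0),
                           bce_sketch(labels, probs))

# At gamma = 2, an easy example (p_t = 0.9) keeps (1 - 0.9)**2 = 1% of its
# cross-entropy, while a hard one (p_t = 0.3) keeps (1 - 0.3)**2 = 49%.
easy_ratio = focal_loss_sketch([1], [0.9]) / bce_sketch([1], [0.9])
hard_ratio = focal_loss_sketch([1], [0.3]) / bce_sketch([1], [0.3])
```

The ratios show the effect described above: the easy example's contribution is cut far more sharply than the hard one's, so hard examples dominate the gradient.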

[Figure: focal loss plot]

The focal_loss package provides a function, binary_focal_loss(), and a class, BinaryFocalLoss, which can be used as drop-in replacements for tf.keras.losses functions and classes, respectively.
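The function/class pairing follows the usual tf.keras pattern. A loss function receives its hyperparameters on every call. A loss class fixes them at construction, and its instances are then called with (y_true, y_pred). The sketch below mimics that pairing in plain Python so it runs without TensorFlow. `binary_focal_loss_sketch` and `BinaryFocalLossSketch` are hypothetical names for illustration and are not this package's API.

```python
import math


def binary_focal_loss_sketch(y_true, y_pred, gamma, eps=1e-7):
    """Hypothetical functional form: gamma is passed on every call."""
    losses = []
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        p_t = p if y == 1 else 1.0 - p   # probability of the true class
        losses.append(-((1.0 - p_t) ** gamma) * math.log(p_t))
    return sum(losses) / len(losses)


class BinaryFocalLossSketch:
    """Hypothetical class form: gamma is fixed at construction, and the
    instance is called with (y_true, y_pred) like a Keras loss object."""

    def __init__(self, gamma, eps=1e-7):
        self.gamma = gamma
        self.eps = eps

    def __call__(self, y_true, y_pred):
        # Delegate to the functional form so both give identical results.
        return binary_focal_loss_sketch(y_true, y_pred, self.gamma, self.eps)


loss_fn = BinaryFocalLossSketch(gamma=2)
value = loss_fn([1, 0], [0.8, 0.4])
```

Because Keras calls a loss with only y_true and y_pred, the class form is the convenient choice when a hyperparameter such as gamma must be set in advance, as in the compile() call that follows.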

# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)