Focal Loss

TensorFlow implementation of focal loss: a loss function that generalizes binary and multiclass cross-entropy loss by down-weighting well-classified examples, so that training concentrates on hard-to-classify examples.
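
For a single example, focal loss multiplies the cross-entropy term -log(p_t) by a modulating factor (1 - p_t)^gamma, where p_t is the predicted probability of the true class. With gamma = 0 the factor is 1 and the loss reduces to ordinary cross-entropy. The following is a minimal pure-Python sketch of that formula for the binary and the integer-label multiclass case. The function names are hypothetical (not the focal_loss package's API), and details a production implementation may add, such as class-balancing weights or clipping probabilities away from 0, are omitted.

```python
import math
from typing import Sequence


def binary_focal_loss_sketch(y_true: int, p: float, gamma: float = 2.0) -> float:
    """Hypothetical helper: focal loss for one binary example.

    y_true is 0 or 1; p is the predicted probability of the positive class.
    """
    # p_t is the probability the model assigns to the true class.
    p_t = p if y_true == 1 else 1.0 - p
    # (1 - p_t)**gamma shrinks the loss of confident, correct predictions.
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def sparse_categorical_focal_loss_sketch(
    y_true: int, probs: Sequence[float], gamma: float = 2.0
) -> float:
    """Hypothetical helper: focal loss for one multiclass example with an integer label."""
    p_t = probs[y_true]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


# gamma = 0 gives plain binary cross-entropy, -log(0.9).
print(binary_focal_loss_sketch(1, 0.9, gamma=0.0))
# gamma = 2 scales that easy example's loss by (1 - 0.9)**2, roughly 0.01.
print(binary_focal_loss_sketch(1, 0.9, gamma=2.0))
# A hard example (p_t = 0.1) keeps most of its loss: factor (0.9)**2, roughly 0.81.
print(binary_focal_loss_sketch(1, 0.1, gamma=2.0))
```

Raising gamma widens the gap between easy and hard examples, which is what lets focal loss focus training on the examples the model still gets wrong.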

[Figure: focal loss plot]

The focal_loss package provides functions and classes that can be used as off-the-shelf replacements for tf.keras.losses functions and classes, respectively.
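
The function/class split mirrors tf.keras.losses, where, for example, binary_crossentropy is a plain function and BinaryCrossentropy is a class whose instances store hyperparameters and are then called as loss(y_true, y_pred). The plain-Python sketch below only illustrates that pattern: focal_loss_fn and FocalLossWrapper are hypothetical names, not part of the focal_loss package, and a real Keras loss would operate on tensors rather than Python lists.

```python
import math
from typing import List, Sequence


def focal_loss_fn(
    y_true: Sequence[int], y_pred: Sequence[float], gamma: float = 2.0
) -> List[float]:
    """Hypothetical functional form: one binary focal loss value per example."""
    losses = []
    for y, p in zip(y_true, y_pred):
        # Probability assigned to the true class.
        p_t = p if y == 1 else 1.0 - p
        losses.append(-((1.0 - p_t) ** gamma) * math.log(p_t))
    return losses


class FocalLossWrapper:
    """Hypothetical class form: stores gamma and reduces over the batch when called."""

    def __init__(self, gamma: float = 2.0):
        self.gamma = gamma

    def __call__(self, y_true: Sequence[int], y_pred: Sequence[float]) -> float:
        per_example = focal_loss_fn(y_true, y_pred, gamma=self.gamma)
        # Mean over the batch, analogous to the default reduction of Keras loss classes.
        return sum(per_example) / len(per_example)


loss = FocalLossWrapper(gamma=2.0)
print(loss([1, 0, 1], [0.9, 0.2, 0.3]))
```

Keeping the hyperparameter on the instance is what allows an object like this to be passed wherever a single two-argument loss callable is expected, such as the loss argument of model.compile.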

# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)

The focal_loss package includes the functions

and wrapper classes