Focal Loss
TensorFlow implementation of focal loss: a loss function that generalizes binary cross-entropy by down-weighting well-classified examples, so that training concentrates on hard-to-classify ones.
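To make the "generalizes binary cross-entropy" claim concrete, here is a minimal pure-Python sketch of the standard per-example binary focal loss, FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. It omits the optional class-weighting term, and the helper names are hypothetical illustrations, not part of the `focal_loss` package. Setting gamma = 0 makes the modulating factor (1 - p_t)^gamma equal to 1, which recovers ordinary binary cross-entropy.

```python
import math


def binary_focal_loss_single(y_true: int, p: float, gamma: float = 2.0, eps: float = 1e-7) -> float:
    """Focal loss for one example (hypothetical helper, not the package API).

    y_true: label in {0, 1}; p: predicted probability that the label is 1.
    """
    p = min(max(p, eps), 1.0 - eps)          # clip to avoid log(0)
    p_t = p if y_true == 1 else 1.0 - p      # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def binary_cross_entropy_single(y_true: int, p: float, eps: float = 1e-7) -> float:
    """Plain binary cross-entropy for one example, for comparison."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))


# gamma = 0 recovers ordinary binary cross-entropy.
assert math.isclose(binary_focal_loss_single(1, 0.8, gamma=0.0),
                    binary_cross_entropy_single(1, 0.8))

# With gamma = 2, an easy example (p_t = 0.9) keeps about 1% of its
# cross-entropy loss, while a hard one (p_t = 0.1) keeps about 81%.
easy_ratio = binary_focal_loss_single(1, 0.9) / binary_cross_entropy_single(1, 0.9)
hard_ratio = binary_focal_loss_single(1, 0.1) / binary_cross_entropy_single(1, 0.1)
print(f"easy: {easy_ratio:.2f}, hard: {hard_ratio:.2f}")  # prints "easy: 0.01, hard: 0.81"
```

Larger gamma values shrink the loss on confident, correct predictions more aggressively, which is what shifts the training signal toward hard examples.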
The `focal_loss` package provides a function `binary_focal_loss()` and a class `BinaryFocalLoss` that can be used as stand-in replacements for `tf.keras.losses` functions and classes, respectively.
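The function/class split mirrors a convention in `tf.keras.losses`, where, for example, the function `binary_crossentropy` returns per-example losses while the class `BinaryCrossentropy` stores its configuration and reduces the per-example values to a scalar. The pure-Python sketch below illustrates that pattern only; `focal_loss_fn` and `FocalLossSketch` are hypothetical names, and the mean reduction is an assumption standing in for Keras's default reduction.

```python
import math
from typing import List, Sequence


def focal_loss_fn(y_true: Sequence[int], y_pred: Sequence[float],
                  gamma: float, eps: float = 1e-7) -> List[float]:
    """Function form (hypothetical): one focal-loss value per example."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    losses = []
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)      # clip to avoid log(0)
        p_t = p if y == 1 else 1.0 - p       # probability of the true class
        losses.append(-((1.0 - p_t) ** gamma) * math.log(p_t))
    return losses


class FocalLossSketch:
    """Class form (hypothetical): stores gamma once, then averages per-example losses."""

    def __init__(self, gamma: float = 2.0):
        self.gamma = gamma

    def __call__(self, y_true: Sequence[int], y_pred: Sequence[float]) -> float:
        per_example = focal_loss_fn(y_true, y_pred, self.gamma)
        if not per_example:
            raise ValueError("empty batch")
        return sum(per_example) / len(per_example)


# The class instance can be passed around like a configured loss object.
loss = FocalLossSketch(gamma=2.0)
print(loss([1, 0], [0.9, 0.1]))
```

Keeping the hyperparameters on an object is what lets a class-based loss be handed to an API such as `model.compile(loss=...)` without wrapping the function in a lambda.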
```python
# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)
```