Skip to content

Commit 9935b0e

Browse files
author
Kcstring
committed
Fix KL divergence with zero true labels
1 parent 791deb4 commit 9935b0e

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

machine_learning/loss_functions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,11 +655,16 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
655655
Traceback (most recent call last):
656656
...
657657
ValueError: Input arrays must have the same length.
658+
>>> true_labels = np.array([0.0, 0.5, 0.5])
659+
>>> predicted_probs = np.array([0.2, 0.3, 0.5])
660+
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
661+
0.25541281188299536
658662
"""
659663
if len(y_true) != len(y_pred):
660664
raise ValueError("Input arrays must have the same length.")
661665

662-
kl_loss = y_true * np.log(y_true / y_pred)
666+
non_zero = y_true != 0
667+
kl_loss = y_true[non_zero] * np.log(y_true[non_zero] / y_pred[non_zero])
663668
return np.sum(kl_loss)
664669

665670

0 commit comments

Comments
 (0)