# Initial commit 3f50570 by awsaf49.
import torch
import torch.nn as nn
import torch.nn.functional as F
class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    """``nn.BCEWithLogitsLoss`` with optional label smoothing.

    When ``label_smoothing`` is truthy, targets are replaced by
    ``target * (1 - label_smoothing) + 0.5 * label_smoothing`` before the
    parent loss is evaluated. All other keyword arguments are forwarded to
    ``nn.BCEWithLogitsLoss`` unchanged.
    """

    def __init__(self, label_smoothing=0.0, **kwargs):
        """
        Args:
            label_smoothing (float): Smoothing factor applied to the targets;
                a falsy value (e.g. ``0.0``) leaves the targets untouched.
            **kwargs: Forwarded to ``nn.BCEWithLogitsLoss.__init__``.
        """
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing

    def forward(self, input, target):
        """Return the parent BCE-with-logits loss on (optionally smoothed) targets.

        Args:
            input (Tensor): Predicted logits.
            target (Tensor): Ground-truth targets.
        """
        eps = self.label_smoothing
        if not eps:
            # No smoothing requested: defer straight to the parent loss.
            return super().forward(input, target)

        # Pull every target towards 0.5 by a fraction ``eps``.
        keep = 1.0 - eps
        shift = 0.5 * eps
        smoothed = target * keep + shift
        return super().forward(input, smoothed)
class SigmoidFocalLoss(nn.Module):
    """Sigmoid focal loss computed from logits, with optional label smoothing.

    Per element the loss is ``alpha_t * (1 - p_t) ** gamma * BCE(input, target)``
    where ``p = sigmoid(input)``, ``p_t = p * target + (1 - p) * (1 - target)``
    and ``alpha_t = alpha * target + (1 - alpha) * (1 - target)``.
    """

    def __init__(self, alpha=1, gamma=2, label_smoothing=0.0, reduction="mean"):
        """
        Args:
            alpha (float): Weighting factor in range (0,1) to balance positive vs
                negative examples. A negative value disables the weighting.
                Note: with the default ``alpha=1`` the weight equals ``target``,
                so elements whose target is 0 contribute nothing to the loss.
            gamma (float): Focusing parameter to reduce the relative loss for
                well-classified examples.
            label_smoothing (float): Label smoothing factor to reduce the
                confidence of the true label; a falsy value disables it.
            reduction (str): Specifies the reduction to apply to the output:
                'none' | 'mean' | 'sum'.
                'none': no reduction will be applied,
                'mean': the sum of the output will be divided by the number of
                elements in the output,
                'sum': the output will be summed.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input (Tensor): Predicted logits for each example.
            target (Tensor): Ground truth binary labels (0 or 1) for each
                example. Integer and bool tensors are cast to ``input``'s dtype.

        Returns:
            Tensor: Element-wise loss for ``reduction='none'``, otherwise the
            mean or sum over all elements.

        Raises:
            ValueError: If ``self.reduction`` is not 'none', 'mean' or 'sum'.
        """
        # Integer/bool labels would be rejected by the BCE kernel (or by
        # ``1 - target`` for bool); floating targets are left untouched so
        # existing results are unchanged.
        if not torch.is_floating_point(target):
            target = target.to(dtype=input.dtype)

        if self.label_smoothing:
            target = target * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing

        p = torch.sigmoid(input)
        ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction="none")

        # p_t is the probability assigned to the (possibly smoothed) target;
        # the modulating factor shrinks the loss where p_t is already high.
        p_t = p * target + (1 - p) * (1 - target)
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        # A negative alpha switches the class-balancing weight off.
        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        if self.reduction == "none":
            return loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{self.reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )