""" EfficientDet Focal, Huber/Smooth L1 loss fns w/ jit support | |
Based on loss fn in Google's automl EfficientDet repository (Apache 2.0 license). | |
https://github.com/google/automl/tree/master/efficientdet | |
Copyright 2020 Ross Wightman | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from typing import Optional, List, Tuple | |
def focal_loss_legacy(logits, targets, alpha: float, gamma: float, normalizer):
    """Compute the focal loss between `logits` and the golden `target` values.

    'Legacy' focal loss matches the loss used in the official Tensorflow impl for initial
    model releases and some time after that. It eventually transitioned to the 'New' loss
    defined below.

    Focal loss = -(1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.

    Args:
        logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
        targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
        alpha: A float32 scalar multiplying alpha to the loss from positive examples
            and (1-alpha) to the loss from negative examples.
        gamma: A float32 scalar modulating loss from hard and easy examples.
        normalizer: A float32 scalar that normalizes the total loss from all examples.

    Returns:
        loss: A float32 tensor of the same shape as `logits` with the per-element weighted
            focal loss divided by `normalizer`; reduction is left to the caller.
    """
    positive_label_mask = targets == 1.0
    cross_entropy = F.binary_cross_entropy_with_logits(logits, targets.to(logits.dtype), reduction='none')
    neg_logits = -1.0 * logits
    modulator = torch.exp(gamma * targets * neg_logits - gamma * torch.log1p(torch.exp(neg_logits)))
    loss = modulator * cross_entropy
    weighted_loss = torch.where(positive_label_mask, alpha * loss, (1.0 - alpha) * loss)
    return weighted_loss / normalizer
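
# Illustrative sketch (not part of the original module): the legacy focal loss returns a
# per-element tensor, so callers reduce it themselves. With gamma == 0 the modulator is 1
# and the result collapses to alpha-weighted BCE. The shapes below are assumptions.
#
#   logits = torch.randn(2, 64, 64, 9 * 90)
#   targets = (torch.rand_like(logits) > 0.99).float()
#   per_elem = focal_loss_legacy(logits, targets, alpha=0.25, gamma=1.5, normalizer=100.0)
#   loss = per_elem.sum()  # same reduction performed in loss_fn below
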
def new_focal_loss(logits, targets, alpha: float, gamma: float, normalizer, label_smoothing: float = 0.01):
    """Compute the focal loss between `logits` and the golden `target` values.

    'New' is not the best descriptor, but this focal loss impl matches recent versions of
    the official Tensorflow impl of EfficientDet. It has support for label smoothing; however,
    it is a bit slower, doesn't jit optimize well, and uses more memory.

    Focal loss = -(1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.

    Args:
        logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
        targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
        alpha: A float32 scalar multiplying alpha to the loss from positive examples
            and (1-alpha) to the loss from negative examples.
        gamma: A float32 scalar modulating loss from hard and easy examples.
        normalizer: Divide loss by this value.
        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.

    Returns:
        loss: A float32 tensor of the same shape as `logits` with the per-element focal loss
            divided by `normalizer`; reduction is left to the caller.
    """
    # compute focal loss multipliers before label smoothing, such that it will not blow up the loss.
    pred_prob = logits.sigmoid()
    targets = targets.to(logits.dtype)
    onem_targets = 1. - targets
    p_t = (targets * pred_prob) + (onem_targets * (1. - pred_prob))
    alpha_factor = targets * alpha + onem_targets * (1. - alpha)
    modulating_factor = (1. - p_t) ** gamma

    # apply label smoothing for cross_entropy for each entry.
    if label_smoothing > 0.:
        targets = targets * (1. - label_smoothing) + .5 * label_smoothing
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

    # compute the final loss and return
    return (1 / normalizer) * alpha_factor * modulating_factor * ce
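
# Illustrative sketch (assumed shapes): with label_smoothing=0.1, hard 0/1 targets become
# 0.05/0.95 inside the BCE term only; alpha_factor and modulating_factor are computed from
# the original hard targets above, which is why smoothing is applied after the multipliers.
#
#   logits = torch.randn(2, 64, 64, 9 * 90)
#   targets = (torch.rand_like(logits) > 0.99).float()
#   per_elem = new_focal_loss(logits, targets, alpha=0.25, gamma=1.5, normalizer=100.0, label_smoothing=0.1)
#   loss = per_elem.sum()
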
def huber_loss(
        input, target, delta: float = 1., weights: Optional[torch.Tensor] = None, size_average: bool = True):
    """Huber loss: quadratic for |input - target| <= delta, linear beyond that.
    """
    err = input - target
    abs_err = err.abs()
    quadratic = torch.clamp(abs_err, max=delta)
    linear = abs_err - quadratic
    loss = 0.5 * quadratic.pow(2) + delta * linear
    if weights is not None:
        loss *= weights
    if size_average:
        return loss.mean()
    else:
        return loss.sum()
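
# Sanity-check sketch (not from the original source): for |err| <= delta the loss is
# 0.5 * err**2; beyond that it grows linearly as delta * (|err| - 0.5 * delta).
#
#   x = torch.tensor([0.5, 3.0])
#   y = torch.zeros(2)
#   huber_loss(x, y, delta=1., size_average=False)  # 0.5 * 0.25 + (3.0 - 0.5) = 2.625
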
def smooth_l1_loss(
        input, target, beta: float = 1. / 9, weights: Optional[torch.Tensor] = None, size_average: bool = True):
    """Very similar to the smooth_l1_loss from PyTorch, but with an extra beta parameter.
    """
    if beta < 1e-5:
        # if beta == 0, then torch.where will result in nan gradients when
        # the chain rule is applied due to pytorch implementation details
        # (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
        # zeros, rather than "no gradient"). To avoid this issue, we define
        # small values of beta to be exactly l1 loss.
        loss = torch.abs(input - target)
    else:
        err = torch.abs(input - target)
        loss = torch.where(err < beta, 0.5 * err.pow(2) / beta, err - 0.5 * beta)
    if weights is not None:
        loss *= weights
    if size_average:
        return loss.mean()
    else:
        return loss.sum()
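
# Sanity-check sketch: at |err| == beta both branches agree (0.5 * beta), so the loss is
# continuous; below beta it is quadratic, above it is L1 shifted down by 0.5 * beta.
#
#   x = torch.tensor([1. / 9, 1.0])
#   y = torch.zeros(2)
#   smooth_l1_loss(x, y, beta=1. / 9, size_average=False)  # (beta - 0.5*beta) + (1.0 - 0.5*beta) = 1.0
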
def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1):
    """Computes box regression loss."""
    # delta is typically around the mean value of regression target.
    # for instance, the regression targets of 512x512 input with 6 anchors on
    # P3-P7 pyramid is about [0.1, 0.1, 0.2, 0.2].
    normalizer = num_positives * 4.0
    mask = box_targets != 0.0
    box_loss = huber_loss(box_outputs, box_targets, weights=mask, delta=delta, size_average=False)
    return box_loss / normalizer
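
# Illustrative note (assumed shapes): box targets that are exactly 0.0 are treated as
# "no assigned box" and masked out of the Huber loss; the sum is then normalized by
# 4 * num_positives, since each positive anchor contributes four regression targets.
#
#   box_out = torch.randn(2, 64, 64, 9 * 4)
#   box_tgt = torch.zeros_like(box_out)  # all-negative level -> loss is exactly 0
#   _box_loss(box_out, box_tgt, num_positives=torch.tensor(8.0))
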
def one_hot(x, num_classes: int):
    # NOTE: PyTorch one-hot does not handle -ve entries (no hot) like Tensorflow, so mask them out
    x_non_neg = (x >= 0).unsqueeze(-1)
    onehot = torch.zeros(x.shape + (num_classes,), device=x.device, dtype=torch.float32)
    return onehot.scatter(-1, x.unsqueeze(-1) * x_non_neg, 1) * x_non_neg
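
# Illustrative sketch of the masking behaviour described in the NOTE above: negative class
# ids (used for unmatched / ignored anchors) yield an all-zero row instead of an error.
#
#   one_hot(torch.tensor([2, 0, -1]), num_classes=4)
#   # tensor([[0., 0., 1., 0.],
#   #         [1., 0., 0., 0.],
#   #         [0., 0., 0., 0.]])
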
def loss_fn(
        cls_outputs: List[torch.Tensor],
        box_outputs: List[torch.Tensor],
        cls_targets: List[torch.Tensor],
        box_targets: List[torch.Tensor],
        num_positives: torch.Tensor,
        num_classes: int,
        alpha: float,
        gamma: float,
        delta: float,
        box_loss_weight: float,
        label_smoothing: float = 0.,
        new_focal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Computes total detection loss.

    Computes total detection loss including box and class loss from all levels.

    Args:
        cls_outputs: a List with values representing logits in [batch_size, height, width, num_anchors]
            at each feature level (index).
        box_outputs: a List with values representing box regression targets in
            [batch_size, height, width, num_anchors * 4] at each feature level (index).
        cls_targets: groundtruth class targets.
        box_targets: groundtruth box targets.
        num_positives: num positive groundtruth anchors.

    Returns:
        total_loss: a float tensor representing total loss reducing from class and box losses from all levels.
        cls_loss: a float tensor representing total class loss.
        box_loss: a float tensor representing total box regression loss.
    """
    # Sum all positives in a batch for normalization and avoid zero
    # num_positives_sum, which would lead to inf loss during training
    num_positives_sum = (num_positives.sum() + 1.0).float()
    levels = len(cls_outputs)

    cls_losses = []
    box_losses = []
    for l in range(levels):
        cls_targets_at_level = cls_targets[l]
        box_targets_at_level = box_targets[l]

        # Onehot encoding for classification labels.
        cls_targets_at_level_oh = one_hot(cls_targets_at_level, num_classes)

        bs, height, width, _, _ = cls_targets_at_level_oh.shape
        cls_targets_at_level_oh = cls_targets_at_level_oh.view(bs, height, width, -1)
        cls_outputs_at_level = cls_outputs[l].permute(0, 2, 3, 1).float()
        if new_focal:
            cls_loss = new_focal_loss(
                cls_outputs_at_level, cls_targets_at_level_oh,
                alpha=alpha, gamma=gamma, normalizer=num_positives_sum, label_smoothing=label_smoothing)
        else:
            cls_loss = focal_loss_legacy(
                cls_outputs_at_level, cls_targets_at_level_oh,
                alpha=alpha, gamma=gamma, normalizer=num_positives_sum)
        cls_loss = cls_loss.view(bs, height, width, -1, num_classes)
        cls_loss = cls_loss * (cls_targets_at_level != -2).unsqueeze(-1)
        cls_losses.append(cls_loss.sum())  # FIXME reference code added a clamp here at some point ...clamp(0, 2))

        box_losses.append(_box_loss(
            box_outputs[l].permute(0, 2, 3, 1).float(),
            box_targets_at_level,
            num_positives_sum,
            delta=delta))

    # Sum per level losses to total loss.
    cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1)
    box_loss = torch.sum(torch.stack(box_losses, dim=-1), dim=-1)
    total_loss = cls_loss + box_loss_weight * box_loss
    return total_loss, cls_loss, box_loss


loss_jit = torch.jit.script(loss_fn)
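
# Illustrative sketch of calling the functional form directly; the feature map sizes, anchor
# count, class count and hyper-parameters below are assumptions, not values from this module.
#
#   num_classes, num_anchors = 90, 9
#   sizes = (64, 32, 16, 8, 4)
#   cls_out = [torch.randn(2, num_anchors * num_classes, s, s) for s in sizes]
#   box_out = [torch.randn(2, num_anchors * 4, s, s) for s in sizes]
#   cls_tgt = [torch.full((2, s, s, num_anchors), -1, dtype=torch.int64) for s in sizes]
#   box_tgt = [torch.zeros(2, s, s, num_anchors * 4) for s in sizes]
#   num_pos = torch.tensor([0., 0.])
#   total, cls_l, box_l = loss_fn(
#       cls_out, box_out, cls_tgt, box_tgt, num_pos, num_classes=num_classes,
#       alpha=0.25, gamma=1.5, delta=0.1, box_loss_weight=50., new_focal=False)
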
class DetectionLoss(nn.Module):

    __constants__ = ['num_classes']

    def __init__(self, config):
        super(DetectionLoss, self).__init__()
        self.config = config
        self.num_classes = config.num_classes
        self.alpha = config.alpha
        self.gamma = config.gamma
        self.delta = config.delta
        self.box_loss_weight = config.box_loss_weight
        self.label_smoothing = config.label_smoothing
        self.new_focal = config.new_focal
        self.use_jit = config.jit_loss

    def forward(
            self,
            cls_outputs: List[torch.Tensor],
            box_outputs: List[torch.Tensor],
            cls_targets: List[torch.Tensor],
            box_targets: List[torch.Tensor],
            num_positives: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        l_fn = loss_fn
        if not torch.jit.is_scripting() and self.use_jit:
            # This branch only active if parent / bench itself isn't being scripted
            # NOTE: I haven't figured out what to do here wrt tracing, is it an issue?
            l_fn = loss_jit

        return l_fn(
            cls_outputs, box_outputs, cls_targets, box_targets, num_positives,
            num_classes=self.num_classes, alpha=self.alpha, gamma=self.gamma, delta=self.delta,
            box_loss_weight=self.box_loss_weight, label_smoothing=self.label_smoothing, new_focal=self.new_focal)
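
# Illustrative sketch of wiring up the module form; in real training code the config comes
# from the model/bench config, the SimpleNamespace below is just a hypothetical stand-in
# carrying the attributes this class reads, and the inputs are those from the sketch above.
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(
#       num_classes=90, alpha=0.25, gamma=1.5, delta=0.1, box_loss_weight=50.,
#       label_smoothing=0., new_focal=False, jit_loss=False)
#   loss_module = DetectionLoss(config)
#   total, cls_l, box_l = loss_module(cls_out, box_out, cls_tgt, box_tgt, num_pos)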