""" EfficientDet Focal, Huber/Smooth L1 loss fns w/ jit support
Based on loss fn in Google's automl EfficientDet repository (Apache 2.0 license).
https://github.com/google/automl/tree/master/efficientdet
Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple
def focal_loss_legacy(logits, targets, alpha: float, gamma: float, normalizer):
"""Compute the focal loss between `logits` and the golden `target` values.
    The 'legacy' focal loss matches the loss used in the official Tensorflow impl for the
    initial model releases and some time after that. It eventually transitioned to the 'new'
    loss defined below.
Focal loss = -(1-pt)^gamma * log(pt)
where pt is the probability of being classified to the true class.
Args:
logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
alpha: A float32 scalar multiplying alpha to the loss from positive examples
and (1-alpha) to the loss from negative examples.
gamma: A float32 scalar modulating loss from hard and easy examples.
normalizer: A float32 scalar normalizes the total loss from all examples.
Returns:
loss: A float32 scalar representing normalized total loss.
"""
positive_label_mask = targets == 1.0
cross_entropy = F.binary_cross_entropy_with_logits(logits, targets.to(logits.dtype), reduction='none')
neg_logits = -1.0 * logits
modulator = torch.exp(gamma * targets * neg_logits - gamma * torch.log1p(torch.exp(neg_logits)))
loss = modulator * cross_entropy
weighted_loss = torch.where(positive_label_mask, alpha * loss, (1.0 - alpha) * loss)
return weighted_loss / normalizer
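

# Illustrative usage sketch (added; shapes and hyper-parameter values below are assumptions):
#
#   logits = torch.randn(2, 64, 64, 9 * 90)             # [batch, H, W, num_anchors * num_classes]
#   targets = (torch.rand_like(logits) > 0.99).float()  # sparse one-hot style class targets
#   loss = focal_loss_legacy(logits, targets, alpha=0.25, gamma=1.5, normalizer=torch.tensor(100.))
#   # the result is element-wise; loss_fn below masks out ignore entries before summing it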
def new_focal_loss(logits, targets, alpha: float, gamma: float, normalizer, label_smoothing: float = 0.01):
"""Compute the focal loss between `logits` and the golden `target` values.
    'New' is not the best descriptor, but this focal loss impl matches recent versions of
    the official Tensorflow impl of EfficientDet. It has support for label smoothing; however,
    it is a bit slower, doesn't jit optimize well, and uses more memory.
Focal loss = -(1-pt)^gamma * log(pt)
where pt is the probability of being classified to the true class.
Args:
logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
alpha: A float32 scalar multiplying alpha to the loss from positive examples
and (1-alpha) to the loss from negative examples.
gamma: A float32 scalar modulating loss from hard and easy examples.
normalizer: Divide loss by this value.
label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
Returns:
loss: A float32 scalar representing normalized total loss.
"""
# compute focal loss multipliers before label smoothing, such that it will not blow up the loss.
pred_prob = logits.sigmoid()
targets = targets.to(logits.dtype)
onem_targets = 1. - targets
p_t = (targets * pred_prob) + (onem_targets * (1. - pred_prob))
alpha_factor = targets * alpha + onem_targets * (1. - alpha)
modulating_factor = (1. - p_t) ** gamma
# apply label smoothing for cross_entropy for each entry.
if label_smoothing > 0.:
targets = targets * (1. - label_smoothing) + .5 * label_smoothing
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
# compute the final loss and return
return (1 / normalizer) * alpha_factor * modulating_factor * ce
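

# Illustrative usage (added; the alpha/gamma/normalizer values are assumptions): with
# label_smoothing=0.01 the hard targets {0, 1} become {0.005, 0.995} inside the cross-entropy
# term, while the focal modulating factor is still computed against the original hard targets.
#
#   loss = new_focal_loss(logits, targets, alpha=0.25, gamma=1.5,
#                         normalizer=torch.tensor(100.), label_smoothing=0.01)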
def huber_loss(
input, target, delta: float = 1., weights: Optional[torch.Tensor] = None, size_average: bool = True):
"""
"""
err = input - target
abs_err = err.abs()
quadratic = torch.clamp(abs_err, max=delta)
linear = abs_err - quadratic
loss = 0.5 * quadratic.pow(2) + delta * linear
if weights is not None:
loss *= weights
if size_average:
return loss.mean()
else:
return loss.sum()
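

# Quick numeric check (added, illustrative): with delta=1.0 an absolute error of 0.5 stays in the
# quadratic region (0.5 * 0.5**2 = 0.125) while an error of 3.0 falls in the linear region
# (0.5 * 1.0**2 + 1.0 * (3.0 - 1.0) = 2.5).
#
#   huber_loss(torch.tensor([0.5, 3.0]), torch.tensor([0.0, 0.0]), delta=1.0, size_average=False)
#   # -> tensor(2.6250)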
def smooth_l1_loss(
input, target, beta: float = 1. / 9, weights: Optional[torch.Tensor] = None, size_average: bool = True):
"""
very similar to the smooth_l1_loss from pytorch, but with the extra beta parameter
"""
if beta < 1e-5:
# if beta == 0, then torch.where will result in nan gradients when
# the chain rule is applied due to pytorch implementation details
# (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
# zeros, rather than "no gradient"). To avoid this issue, we define
# small values of beta to be exactly l1 loss.
loss = torch.abs(input - target)
else:
err = torch.abs(input - target)
loss = torch.where(err < beta, 0.5 * err.pow(2) / beta, err - 0.5 * beta)
if weights is not None:
loss *= weights
if size_average:
return loss.mean()
else:
return loss.sum()
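

# Relationship note (added): this smooth L1 divides the quadratic term by beta, so it equals
# huber_loss(..., delta=beta) scaled by 1/beta; with the default beta=1/9 the switch to the
# linear regime happens at a much smaller error than huber_loss's default delta of 1.0.
#
#   smooth_l1_loss(torch.tensor([1.0]), torch.tensor([0.0]), beta=0.5, size_average=False)
#   # -> tensor(0.7500)   (err = 1.0 >= beta, so loss = err - 0.5 * beta)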
def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1):
"""Computes box regression loss."""
    # delta is typically around the mean value of the regression targets.
    # For instance, the regression targets of a 512x512 input with 6 anchors on the
    # P3-P7 pyramid are about [0.1, 0.1, 0.2, 0.2].
normalizer = num_positives * 4.0
mask = box_targets != 0.0
box_loss = huber_loss(box_outputs, box_targets, weights=mask, delta=delta, size_average=False)
return box_loss / normalizer
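

# Note (added): the `num_positives * 4` normalizer spreads the summed Huber loss over the four
# box coordinates of every positive anchor, and the `box_targets != 0` mask zeroes out the loss
# from anchors that were never assigned a groundtruth box.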
def one_hot(x, num_classes: int):
    # NOTE: PyTorch one-hot does not handle negative ('no hot') entries like Tensorflow, so mask them out
x_non_neg = (x >= 0).unsqueeze(-1)
onehot = torch.zeros(x.shape + (num_classes,), device=x.device, dtype=torch.float32)
return onehot.scatter(-1, x.unsqueeze(-1) * x_non_neg, 1) * x_non_neg
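

# Illustrative example (added): negative labels (ignore / unassigned entries in the targets)
# produce an all-zero row instead of indexing off the end of the one-hot axis.
#
#   one_hot(torch.tensor([2, -1]), num_classes=4)
#   # -> tensor([[0., 0., 1., 0.],
#   #            [0., 0., 0., 0.]])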
def loss_fn(
cls_outputs: List[torch.Tensor],
box_outputs: List[torch.Tensor],
cls_targets: List[torch.Tensor],
box_targets: List[torch.Tensor],
num_positives: torch.Tensor,
num_classes: int,
alpha: float,
gamma: float,
delta: float,
box_loss_weight: float,
label_smoothing: float = 0.,
new_focal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes total detection loss.
Computes total detection loss including box and class loss from all levels.
Args:
        cls_outputs: a List of tensors containing class logits of shape
            [batch_size, num_anchors * num_classes, height, width], one per feature level (index).
        box_outputs: a List of tensors containing box regression outputs of shape
            [batch_size, num_anchors * 4, height, width], one per feature level (index).
        cls_targets: groundtruth class targets.
        box_targets: groundtruth box targets.
        num_positives: number of positive groundtruth anchors per image.
    Returns:
        total_loss: a float tensor representing the total loss, summed from the class and box losses of all levels.
        cls_loss: a float tensor representing the total class loss.
        box_loss: a float tensor representing the total box regression loss.
"""
# Sum all positives in a batch for normalization and avoid zero
# num_positives_sum, which would lead to inf loss during training
num_positives_sum = (num_positives.sum() + 1.0).float()
levels = len(cls_outputs)
cls_losses = []
box_losses = []
for l in range(levels):
cls_targets_at_level = cls_targets[l]
box_targets_at_level = box_targets[l]
# Onehot encoding for classification labels.
cls_targets_at_level_oh = one_hot(cls_targets_at_level, num_classes)
bs, height, width, _, _ = cls_targets_at_level_oh.shape
cls_targets_at_level_oh = cls_targets_at_level_oh.view(bs, height, width, -1)
cls_outputs_at_level = cls_outputs[l].permute(0, 2, 3, 1).float()
if new_focal:
cls_loss = new_focal_loss(
cls_outputs_at_level, cls_targets_at_level_oh,
alpha=alpha, gamma=gamma, normalizer=num_positives_sum, label_smoothing=label_smoothing)
else:
cls_loss = focal_loss_legacy(
cls_outputs_at_level, cls_targets_at_level_oh,
alpha=alpha, gamma=gamma, normalizer=num_positives_sum)
cls_loss = cls_loss.view(bs, height, width, -1, num_classes)
cls_loss = cls_loss * (cls_targets_at_level != -2).unsqueeze(-1)
cls_losses.append(cls_loss.sum()) # FIXME reference code added a clamp here at some point ...clamp(0, 2))
box_losses.append(_box_loss(
box_outputs[l].permute(0, 2, 3, 1).float(),
box_targets_at_level,
num_positives_sum,
delta=delta))
# Sum per level losses to total loss.
cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1)
box_loss = torch.sum(torch.stack(box_losses, dim=-1), dim=-1)
total_loss = cls_loss + box_loss_weight * box_loss
return total_loss, cls_loss, box_loss
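

# TorchScript-compiled version of loss_fn; DetectionLoss.forward switches to it when
# config.jit_loss is set and the surrounding module is not itself being scripted.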
loss_jit = torch.jit.script(loss_fn)
class DetectionLoss(nn.Module):
__constants__ = ['num_classes']
def __init__(self, config):
super(DetectionLoss, self).__init__()
self.config = config
self.num_classes = config.num_classes
self.alpha = config.alpha
self.gamma = config.gamma
self.delta = config.delta
self.box_loss_weight = config.box_loss_weight
self.label_smoothing = config.label_smoothing
self.new_focal = config.new_focal
self.use_jit = config.jit_loss
def forward(
self,
cls_outputs: List[torch.Tensor],
box_outputs: List[torch.Tensor],
cls_targets: List[torch.Tensor],
box_targets: List[torch.Tensor],
num_positives: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
l_fn = loss_fn
if not torch.jit.is_scripting() and self.use_jit:
# This branch only active if parent / bench itself isn't being scripted
# NOTE: I haven't figured out what to do here wrt to tracing, is it an issue?
l_fn = loss_jit
return l_fn(
cls_outputs, box_outputs, cls_targets, box_targets, num_positives,
num_classes=self.num_classes, alpha=self.alpha, gamma=self.gamma, delta=self.delta,
box_loss_weight=self.box_loss_weight, label_smoothing=self.label_smoothing, new_focal=self.new_focal)
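

# The sketch below (added, illustrative only) wires DetectionLoss up with a minimal config and
# random per-level outputs and targets; the shapes, anchor count, and config values are
# assumptions, not values from the original training setup.
if __name__ == '__main__':
    from types import SimpleNamespace

    num_classes, num_anchors, num_levels = 20, 9, 3
    config = SimpleNamespace(
        num_classes=num_classes, alpha=0.25, gamma=1.5, delta=0.1,
        box_loss_weight=50.0, label_smoothing=0.0, new_focal=False, jit_loss=False)
    loss_module = DetectionLoss(config)

    batch_size = 2
    cls_outputs, box_outputs, cls_targets, box_targets = [], [], [], []
    for level in range(num_levels):
        size = 8 // (2 ** level)  # e.g. 8x8, 4x4, 2x2 feature maps
        cls_outputs.append(torch.randn(batch_size, num_anchors * num_classes, size, size))
        box_outputs.append(torch.randn(batch_size, num_anchors * 4, size, size))
        # class targets in [-2, num_classes): -2 = ignore, -1 = no assigned class
        cls_targets.append(torch.randint(-2, num_classes, (batch_size, size, size, num_anchors)))
        box_targets.append(torch.randn(batch_size, size, size, num_anchors * 4))
    num_positives = torch.tensor([10.0, 12.0])

    total_loss, cls_loss, box_loss = loss_module(
        cls_outputs, box_outputs, cls_targets, box_targets, num_positives)
    print(total_loss.item(), cls_loss.item(), box_loss.item())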