# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from .utils import weighted_loss


@weighted_loss
def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
    """Element-wise Smooth L1 loss.

    With ``d = |pred - target|``, every element is ``0.5 * d * d / beta``
    where ``d < beta`` and ``d - 0.5 * beta`` elsewhere.

    The function is wrapped by ``weighted_loss`` (imported from ``.utils``);
    callers in this module additionally pass ``weight``, ``reduction`` and
    ``avg_factor`` through that wrapper.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target of the prediction. Unless it
            is empty, it must have the same size as ``pred``.
        beta (float, optional): The threshold in the piecewise function.
            Must be greater than 0. Defaults to 1.0.

    Returns:
        Tensor: The element-wise loss, or ``pred.sum() * 0`` when
        ``target`` has no elements.
    """
    assert beta > 0
    if target.numel() == 0:
        # Built from ``pred`` instead of a fresh constant -- presumably so
        # the result stays attached to the autograd graph (TODO confirm).
        return pred.sum() * 0

    assert pred.size() == target.size()
    abs_err = torch.abs(pred - target)
    quadratic_part = 0.5 * abs_err * abs_err / beta
    linear_part = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic_part, linear_part)


@weighted_loss
def l1_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Element-wise L1 loss, i.e. ``|pred - target|``.

    The function is wrapped by ``weighted_loss`` (imported from ``.utils``);
    callers in this module additionally pass ``weight``, ``reduction`` and
    ``avg_factor`` through that wrapper.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target of the prediction. Unless it
            is empty, it must have the same size as ``pred``.

    Returns:
        Tensor: The element-wise loss, or ``pred.sum() * 0`` when
        ``target`` has no elements.
    """
    if target.numel() == 0:
        # Same empty-target shortcut as in ``smooth_l1_loss``.
        return pred.sum() * 0

    assert pred.size() == target.size()
    return torch.abs(pred - target)


def _no_positive_weight_loss(pred: Tensor, weight: Tensor) -> Tensor:
    """Shortcut result used when ``weight`` has no entry greater than 0.

    Args:
        pred (Tensor): The prediction.
        weight (Tensor): The per-prediction weight. When ``pred`` has
            exactly one more dimension than ``weight``, a singleton
            dimension is inserted at index 1 before multiplying.

    Returns:
        Tensor: ``(pred * weight).sum()``.
    """
    if pred.dim() == weight.dim() + 1:
        weight = weight.unsqueeze(1)
    return (pred * weight).sum()


@MODELS.register_module()
class SmoothL1Loss(nn.Module):
    """Smooth L1 loss module.

    Args:
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float, optional): Factor the computed loss is
            multiplied by. Defaults to 1.0.
    """

    def __init__(self,
                 beta: float = 1.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Compute the weighted Smooth L1 loss.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                One of None, "none", "mean" or "sum". Defaults to None.
            **kwargs: Forwarded unchanged to ``smooth_l1_loss``.

        Returns:
            Tensor: Calculated loss. When ``weight`` is given and none of
            its entries is greater than 0, this is ``(pred * weight).sum()``
            and neither ``loss_weight`` nor the reduction is applied.
        """
        if weight is not None and not torch.any(weight > 0):
            return _no_positive_weight_loss(pred, weight)

        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the configured reduction.
        reduction = reduction_override or self.reduction
        unscaled_loss = smooth_l1_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * unscaled_loss


@MODELS.register_module()
class L1Loss(nn.Module):
    """L1 loss module.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float, optional): Factor the computed loss is
            multiplied by. Defaults to 1.0.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the weighted L1 loss.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                One of None, "none", "mean" or "sum". Defaults to None.

        Returns:
            Tensor: Calculated loss. When ``weight`` is given and none of
            its entries is greater than 0, this is ``(pred * weight).sum()``
            and neither ``loss_weight`` nor the reduction is applied.
        """
        if weight is not None and not torch.any(weight > 0):
            return _no_positive_weight_loss(pred, weight)

        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the configured reduction.
        reduction = reduction_override or self.reduction
        unscaled_loss = l1_loss(
            pred,
            target,
            weight,
            reduction=reduction,
            avg_factor=avg_factor)
        return self.loss_weight * unscaled_loss