import torch
from torch import Tensor


def beta_smooth_l1_loss(input: Tensor, target: Tensor, beta: float) -> Tensor:
    """Return the mean smooth-L1 loss between ``input`` and ``target``.

    With ``d = |input - target|`` (elementwise, after broadcasting), each
    element contributes ``0.5 * d**2 / beta`` when ``d < beta`` and
    ``d - 0.5 * beta`` otherwise. The per-element losses are then averaged.

    Args:
        input: Predictions.
        target: Targets; must be broadcastable with ``input``.
        beta: Transition point between the quadratic and linear regions.
            Must be non-negative. Values below ``1e-5`` (including 0) use
            plain L1 loss, which is the ``beta -> 0`` limit.

    Returns:
        A 0-dim tensor holding the mean loss over all broadcast elements,
        or 0 when there are no elements.

    Raises:
        ValueError: If ``beta`` is negative.
    """
    if beta < 0:
        raise ValueError(f"beta must be non-negative, got {beta}")

    diff = torch.abs(input - target)
    if beta < 1e-5:
        # The quadratic branch divides by beta. For beta ~ 0 it produces
        # inf/NaN, and torch.where still back-propagates a (zero) gradient
        # through the unselected branch: 0 / 0 = NaN poisons the gradients.
        # Use the exact limit of the loss instead: plain L1.
        loss = diff
    else:
        loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)

    # Average over the elements actually summed (the broadcast shape, not
    # just input's). Guard the empty case so it yields 0 rather than NaN.
    return loss.sum() / max(loss.numel(), 1)