File size: 299 Bytes
d6def08 |
1 2 3 4 5 6 7 8 9 10 11 |
import torch
from torch import Tensor
def beta_smooth_l1_loss(input: Tensor, target: Tensor, beta: float) -> Tensor:
    """Return the mean smooth-L1 loss between ``input`` and ``target``.

    Element-wise, with ``d = |input - target|``::

        loss = 0.5 * d**2 / beta   if d < beta
        loss = d - 0.5 * beta      otherwise

    The element-wise losses are then averaged over every element of the
    (broadcast) difference.

    Args:
        input: Predicted values.
        target: Reference values; must be broadcastable with ``input``.
        beta: Threshold between the quadratic and linear regimes. ``beta == 0``
            reduces to plain L1 loss.

    Returns:
        A 0-dim tensor holding the mean loss. Empty inputs yield ``0``.

    Raises:
        ValueError: If ``beta`` is negative.
    """
    if beta < 0:
        raise ValueError(f"beta must be non-negative, got {beta}")

    diff = torch.abs(input - target)
    if beta == 0:
        # torch.where evaluates both branches, so dividing by a zero beta would
        # produce 0/0 = NaN in the unused branch and leak NaN into gradients.
        loss = diff
    else:
        loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)

    count = loss.numel()  # counts broadcast elements, not just input's
    if count == 0:
        # Empty input: return a zero-valued tensor instead of NaN from 0/0.
        return loss.sum()
    return loss.sum() / count
|