import torch
import torch.nn as nn
class Loss(nn.modules.loss._Loss):
    """Base class for custom losses; subclass this to implement your own."""

    def __init__(self, **kwargs):
        # Pass reduction-related kwargs straight through to torch's _Loss.
        super().__init__(**kwargs)
class AdditiveMarginSoftmaxLoss(Loss):
    """Computes Additive Margin Softmax (CosFace) Loss.

    Paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition

    Args:
        scale: scale value for cosine angle
        margin: margin value subtracted from the target cosine angle
    """

    def __init__(self, scale=30.0, margin=0.2):
        super().__init__()
        self.eps = 1e-7
        self.scale = scale
        self.margin = margin

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean CosFace loss over the batch.

        Args:
            logits: (batch, num_classes) similarity scores; per the CosFace
                formulation these are assumed to be cosine similarities in
                [-1, 1] — not re-validated here.
            labels: (batch,) integer class indices.
        """
        batch_idx = torch.arange(logits.size(0), device=logits.device)
        # Margin-adjusted target logit: s * (cos(theta_y) - m).
        numerator = self.scale * (logits[batch_idx, labels] - self.margin)
        # BUGFIX: the previous version used scatter_ (in-place), silently
        # overwriting the target column of the CALLER's logits tensor with
        # -inf. Out-of-place scatter masks the target class without side
        # effects; exp(scale * -inf) = 0 drops it from the sum below.
        masked = logits.scatter(1, labels.unsqueeze(1), float('-inf'))
        denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * masked), dim=1)
        # Equivalent to -log(exp(numerator) / denominator), computed in
        # log space for numerical stability.
        loss = numerator - torch.log(denominator)
        return -loss.mean()
class AdditiveAngularMarginSoftmaxLoss(Loss):
    """Computes Additive Angular Margin Softmax (ArcFace) Loss.

    Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition

    Args:
        scale: scale value for cosine angle
        margin: margin value added to the target angle (radians)
    """

    def __init__(self, scale=20.0, margin=1.35):
        super().__init__()
        self.eps = 1e-7
        self.scale = scale
        self.margin = margin

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean ArcFace loss over the batch.

        Args:
            logits: (batch, num_classes) cosine similarities in [-1, 1]
                (the acos below requires this range; values are clamped).
            labels: (batch,) integer class indices.
        """
        batch_idx = torch.arange(logits.size(0), device=logits.device)
        # Target cosine per row, clamped so acos stays finite at +/-1.
        # Replaces the old transpose trick, which materialized a
        # (batch, batch) intermediate tensor.
        cos_target = torch.clamp(logits[batch_idx, labels], -1.0 + self.eps, 1 - self.eps)
        # s * cos(theta_y + m): the margin is added in angle space.
        numerator = self.scale * torch.cos(torch.acos(cos_target) + self.margin)
        # Non-target logits via boolean mask — vectorized replacement for
        # the previous per-row Python loop of slices + torch.cat.
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask[batch_idx, labels] = False
        excl = logits[mask].view(logits.size(0), -1)
        denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * excl), dim=1)
        # Cross-entropy in log space: loss_i = -(numerator_i - log(denominator_i)).
        L = numerator - torch.log(denominator)
        return -torch.mean(L)