File size: 2,347 Bytes
99269d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn


class Loss(nn.modules.loss._Loss):
    """Common base class: subclass it to build a custom loss."""

    def __init__(self, **kwargs):
        # Every keyword option is handed, untouched, to the torch base class.
        super().__init__(**kwargs)


class AdditiveMarginSoftmaxLoss(Loss):
    """Computes Additive Margin Softmax (CosFace) Loss.

    Paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition

    For every sample the target-class logit becomes ``scale * (logit - margin)``
    and every other logit becomes ``scale * logit``; the loss is the softmax
    cross-entropy of the target class, averaged over the batch.

    Args:
        scale: factor multiplied into every logit before the softmax.
        margin: value subtracted from the target-class logit.
    """

    def __init__(self, scale=30.0, margin=0.2):
        super().__init__()

        # Not read by forward(); kept so the public attribute still exists.
        self.eps = 1e-7
        self.scale = scale
        self.margin = margin

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return the batch-mean CosFace loss as a 0-dim tensor.

        Args:
            logits: 2-D tensor indexed as ``[sample, class]``. Presumably
                cosine similarities in [-1, 1] (per the CosFace formulation);
                this method does not normalise them -- confirm at the caller.
            labels: 1-D integer tensor with the target class of each sample.

        The input ``logits`` tensor is never modified.
        """
        index = labels.long().unsqueeze(1)
        # Out-of-place scatter: builds a tensor holding `margin` at each
        # target position and 0 elsewhere, leaving the caller's logits intact
        # (an in-place write would also break autograd on leaf tensors).
        margin_mask = torch.zeros_like(logits).scatter(1, index, self.margin)
        adjusted = self.scale * (logits - margin_mask)
        target = adjusted.gather(1, index).squeeze(1)
        # log(sum(exp(adjusted))) - adjusted[target] == -log(softmax[target]);
        # logsumexp avoids the overflow of a raw exp() for large scales.
        per_sample = torch.logsumexp(adjusted, dim=1) - target
        return per_sample.mean()


class AdditiveAngularMarginSoftmaxLoss(Loss):
    """Computes Additive Angular Margin Softmax (ArcFace) Loss.

    Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition

    For every sample the target-class logit ``c`` is replaced by
    ``scale * cos(acos(c) + margin)`` and every other logit by
    ``scale * logit``; the loss is the softmax cross-entropy of the target
    class, averaged over the batch.

    Args:
        scale: factor multiplied into every logit before the softmax.
        margin: value (radians) added to the target-class angle ``acos(c)``.
    """

    def __init__(self, scale=20.0, margin=1.35):
        super().__init__()

        # Keeps the clamped target logit strictly inside (-1, 1) for acos.
        self.eps = 1e-7
        self.scale = scale
        self.margin = margin

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return the batch-mean ArcFace loss as a 0-dim tensor.

        Args:
            logits: 2-D tensor indexed as ``[sample, class]``. Presumably
                cosine similarities in [-1, 1]; only the target entry is
                clamped here, the others are used as given -- confirm at the
                caller.
            labels: 1-D integer tensor with the target class of each sample.

        NOTE(review): no correction is applied when ``angle + margin``
        exceeds pi, so the penalised target logit is not monotonic in the
        angle there -- confirm this is intended for the default margin.
        """
        index = labels.long().unsqueeze(1)
        # gather reads one target logit per row directly, instead of building
        # a batch x batch matrix and taking its diagonal.
        target_cos = logits.gather(1, index).clamp(-1.0 + self.eps, 1 - self.eps)
        target_with_margin = torch.cos(torch.acos(target_cos) + self.margin)
        # Out-of-place scatter swaps in the penalised target logit without
        # touching the caller's tensor and without a per-sample Python loop.
        adjusted = self.scale * logits.scatter(1, index, target_with_margin)
        numerator = self.scale * target_with_margin.squeeze(1)
        # logsumexp over the adjusted row equals
        # log(exp(numerator) + sum(exp(scale * non_target))) but cannot overflow.
        log_prob = numerator - torch.logsumexp(adjusted, dim=1)
        return -torch.mean(log_prob)