File size: 6,269 Bytes
20239f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from torch import Tensor
from torch.nn import Parameter
from typing import Any
from layers.independent_mlp import IndependentMLPs


# Baseline model: a modified ConvNeXt with reduced downsampling, yielding a spatially larger feature tensor in the last layer
class IndividualLandmarkConvNext(torch.nn.Module):
    """ConvNeXt-based landmark (part) discovery model with per-landmark classification.

    The backbone's ``stem`` and ``stages`` are reused from ``init_model``. The output of
    stage 3 is bilinearly upsampled to the spatial size of the stage-2 output and both are
    concatenated along the channel axis. A bias-free 1x1 convolution holds one prototype
    vector per landmark slot (``num_landmarks + 1`` slots); the negative squared Euclidean
    distance between every pixel feature and every prototype is turned into per-pixel
    attention maps via (Gumbel-)softmax. Features are averaged under each map, optionally
    modulated, and every slot except the last one is classified.
    """

    # Constructor options for each IndependentMLPs-based modulation variant.
    _PARALLEL_MLP_OPTIONS = {
        "parallel_mlp": {"act_layer": True, "bias": True},
        "parallel_mlp_no_bias": {"act_layer": True, "bias": False},
        "parallel_mlp_no_act": {"act_layer": False, "bias": True},
        "parallel_mlp_no_act_no_bias": {"act_layer": False, "bias": False},
    }

    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        """Build the model on top of a pretrained backbone.

        Args:
            init_model: Backbone exposing ``stem`` and an indexable ``stages`` with at
                least four entries (both are read directly in ``forward``).
            num_landmarks: Number of landmark slots that are classified. One extra slot
                is always added to the attention maps and excluded from classification
                (presumably a background slot -- confirm with the training code).
            num_classes: Number of output classes.
            sl_channels: Channel count of the stage-2 output (``stages[2]``).
            fl_channels: Channel count of the stage-3 output (``stages[3]``).
            part_dropout: Probability passed to ``torch.nn.Dropout1d``, applied to the
                per-landmark features before classification.
            modulation_type: One of ``"layer_norm"``, ``"original"``, ``"parallel_mlp"``,
                ``"parallel_mlp_no_bias"``, ``"parallel_mlp_no_act"``,
                ``"parallel_mlp_no_act_no_bias"`` or ``"none"``.
            modulation_orth: If true, ``forward`` returns the modulated features instead
                of the unmodulated ones as its first output.
            gumbel_softmax: Use ``gumbel_softmax`` instead of ``softmax`` for the maps.
            gumbel_softmax_temperature: ``tau`` passed to ``gumbel_softmax``.
            gumbel_softmax_hard: ``hard`` flag passed to ``gumbel_softmax``.
            classifier_type: ``"linear"`` (one classifier shared by all landmarks) or
                ``"independent_mlp"`` (an ``IndependentMLPs`` classifier).
            noise_variance: If greater than zero, Gaussian noise scaled by
                ``x.std() * noise_variance`` is added to the pooled features.

        Raises:
            ValueError: If ``modulation_type`` or ``classifier_type`` is not supported.
        """
        super().__init__()

        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.stem = init_model.stem
        self.stages = init_model.stages
        # forward() concatenates the stage-3 and stage-2 outputs along the channel axis.
        self.feature_dim = sl_channels + fl_channels
        # The kernel of this bias-free 1x1 convolution stores one prototype per landmark slot.
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard

        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            # Learnable element-wise scale, multiplied with the pooled features in forward().
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type in self._PARALLEL_MLP_OPTIONS:
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, **self._PARALLEL_MLP_OPTIONS[modulation_type])
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth

        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)

        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute landmark features, attention maps, class scores and distances.

        Args:
            x: Input batch fed to the backbone stem (assumed to be an image batch of
                shape ``[B, C_in, H_in, W_in]`` -- confirm against the data loader).

        Returns:
            A 4-tuple ``(features, maps, scores, dist)``:

            * ``features``: pooled per-slot features ``[B, feature_dim, num_landmarks + 1]``;
              the modulated version if ``modulation_orth`` is set, otherwise the
              unmodulated one.
            * ``maps``: attention maps ``[B, num_landmarks + 1, H, W]`` (softmax or
              Gumbel-softmax over dim 1).
            * ``scores``: per-landmark class scores; ``[B, num_classes, num_landmarks]``
              for the ``"linear"`` classifier.
            * ``dist``: squared Euclidean distance between every pixel feature and every
              landmark prototype, ``[B, num_landmarks + 1, H, W]``.
        """
        # Pretrained ConvNeXt part of the model
        x = self.stem(x)
        x = self.stages[0](x)
        x = self.stages[1](x)
        l3 = self.stages[2](x)
        x = self.stages[3](l3)
        # Upsample the last stage to the resolution of the previous one and fuse both.
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
        x = torch.cat((x, l3), dim=1)

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = ConvNeXt feature maps, a = convolution kernel
        ab = self.fc_landmarks(x)  # [B, num_landmarks + 1, H, W]
        b_sq = x.pow(2).sum(1, keepdim=True)  # [B, 1, H, W]
        a_sq = self.fc_landmarks.weight.pow(2).sum(1)  # [num_landmarks + 1, 1, 1]

        # Broadcasting expands b_sq over landmarks and a_sq over batch and space, so no
        # full-size copies of either term have to be materialised.
        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
        if self.noise_variance > 0.0:
            # NOTE(review): the noise is added regardless of train/eval mode -- confirm intended.
            noise = torch.randn_like(all_features) * x.std().detach() * self.noise_variance
            all_features = all_features + noise

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features; the last slot is excluded.
        landmark_features = all_features_mod[..., :-1].permute(0, 2, 1).contiguous()
        landmark_features = self.dropout_full_landmarks(landmark_features)
        scores = self.fc_class_landmarks(landmark_features).permute(0, 2, 1).contiguous()

        returned_features = all_features_mod if self.modulation_orth else all_features
        return returned_features, maps, scores, dist