from typing import List

from timm import create_model
from torch import nn

from src.models.vit.factory import create_vit
from src.models.vit.vit import FeatureTransform
from ...utils import print_num_params


class EmbedModel(nn.Module):
    """Frozen ViT backbone that returns four intermediate feature maps.

    The backbone is built by ``create_vit``, switched to eval mode, and has all
    gradients disabled, so it acts purely as a feature extractor. When
    ``n_dim_output == 3`` the raw ViT outputs are additionally passed through a
    ``FeatureTransform`` before being returned.
    """

    def __init__(self, config, head_out_idx: List[int], n_dim_output=3, device="cuda") -> None:
        super().__init__()
        self.head_out_idx = head_out_idx
        self.n_dim_output = n_dim_output
        self.device = device

        # Build and freeze the ViT backbone.
        self.vit = create_vit(config).to(self.device)
        self.vit.eval()
        for params in self.vit.parameters():
            params.requires_grad = False
        print_num_params(self.vit)
        print_num_params(self.vit, is_trainable=True)

        if self.n_dim_output == 3:
            # Maps the raw ViT outputs to spatial feature maps.
            self.feature_transformer = FeatureTransform(config["image_size"], config["d_model"]).to(self.device)
            print_num_params(self.feature_transformer)
            print_num_params(self.feature_transformer, is_trainable=True)

    def forward(self, x):
        # Collect features from the transformer blocks selected by head_out_idx.
        vit_outputs = self.vit(x, self.head_out_idx, n_dim_output=self.n_dim_output, return_features=True)
        feat0, feat1, feat2, feat3 = vit_outputs[0], vit_outputs[1], vit_outputs[2], vit_outputs[3]
        if self.n_dim_output == 3:
            feat0, feat1, feat2, feat3 = self.feature_transformer(vit_outputs)
        return feat0, feat1, feat2, feat3
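
# Usage sketch (assumption, not part of the original module): EmbedModel expects a
# project-specific config dict that create_vit can consume; this class itself reads
# config["image_size"] and config["d_model"]. A hypothetical call might look like:
#
#   model = EmbedModel(config, head_out_idx=[2, 5, 8, 11], n_dim_output=3, device="cuda")
#   feat0, feat1, feat2, feat3 = model(images)  # images: (B, 3, H, W) float tensor
#
# The head_out_idx values above are purely illustrative.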


class GeneralEmbedModel(nn.Module):
    """Frozen timm backbone that returns upsampled multi-scale feature maps.

    Supported ``pretrained_model`` values:
        - "swin-tiny":  swinv2_cr_tiny_ns_224.sw_in1k
        - "swin-small": swinv2_cr_small_ns_224.sw_in1k

    A plain ViT candidate (vit_tiny_patch16_224.augreg_in21k_ft_in1k) is also
    noted but not wired up here.
    """

    def __init__(self, pretrained_model="swin-tiny", device="cuda") -> None:
        super().__init__()
        self.device = device
        self.pretrained_model = pretrained_model
        if pretrained_model == "swin-tiny":
            self.pretrained = create_model(
                "swinv2_cr_tiny_ns_224.sw_in1k",
                pretrained=True,
                features_only=True,
                out_indices=[-4, -3, -2, -1],  # last four feature stages
            ).to(device)
        elif pretrained_model == "swin-small":
            self.pretrained = create_model(
                "swinv2_cr_small_ns_224.sw_in1k",
                pretrained=True,
                features_only=True,
                out_indices=[-4, -3, -2, -1],  # last four feature stages
            ).to(device)
        else:
            raise NotImplementedError(f"Unsupported pretrained_model: {pretrained_model}")

        # Freeze the backbone; it is used only as a feature extractor.
        self.pretrained.eval()
        for params in self.pretrained.parameters():
            params.requires_grad = False

        # Each feature map is upsampled by a factor of 2 before being returned.
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, x):
        outputs = self.pretrained(x)
        outputs = [self.upsample(feat) for feat in outputs]

        return outputs
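

# Minimal usage sketch (assumption, not part of the original module): runs the
# frozen Swin-V2 tiny backbone on a random 224x224 input and prints the shape of
# each upsampled feature map. Because this file uses a relative import
# (from ...utils import ...), run it as a module (python -m <package path>) rather
# than as a bare script, and make sure the timm pretrained weights can be downloaded.
if __name__ == "__main__":
    import torch

    embed = GeneralEmbedModel(pretrained_model="swin-tiny", device="cpu")
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        features = embed(dummy)
    for i, feat in enumerate(features):
        print(f"stage {i}: {tuple(feat.shape)}")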