# Embedding-model wrappers around frozen ViT / Swin backbones.
from typing import List

from einops import rearrange
from timm import create_model
from torch import nn

from src.models.vit.factory import create_vit
from src.models.vit.vit import FeatureTransform
from ...utils import print_num_params
class EmbedModel(nn.Module):
    """Frozen ViT feature extractor that returns four intermediate feature maps.

    The backbone is created from ``config``, put in eval mode, and all of its
    parameters have ``requires_grad`` disabled, so it acts purely as a fixed
    feature extractor. When ``n_dim_output == 3`` the raw ViT outputs are
    additionally passed through a ``FeatureTransform`` module.
    """

    def __init__(self, config, head_out_idx: List[int], n_dim_output: int = 3, device: str = "cuda") -> None:
        """
        Args:
            config: Backbone config dict; must provide ``"image_size"`` and
                ``"d_model"`` when ``n_dim_output == 3``.
            head_out_idx: Indices of the ViT heads/blocks whose outputs are
                returned (assumes four entries — TODO confirm against create_vit).
            n_dim_output: Passed through to the ViT; ``3`` enables the extra
                feature transformer.
            device: Device the backbone (and transformer) are moved to.
        """
        super().__init__()
        self.head_out_idx = head_out_idx
        self.n_dim_output = n_dim_output
        self.device = device
        # Frozen backbone: eval mode + no gradients anywhere.
        self.vit = create_vit(config).to(self.device)
        self.vit.eval()
        for param in self.vit.parameters():
            param.requires_grad = False
        # Log total vs. trainable parameter counts (trainable should be 0).
        print_num_params(self.vit)
        print_num_params(self.vit, is_trainable=True)
        if self.n_dim_output == 3:
            self.feature_transformer = FeatureTransform(config["image_size"], config["d_model"]).to(self.device)
            print_num_params(self.feature_transformer)
            print_num_params(self.feature_transformer, is_trainable=True)

    def forward(self, x):
        """Return the four feature maps extracted from ``x``.

        When ``n_dim_output == 3`` the ViT outputs are reshaped by the
        feature transformer before being returned.
        """
        vit_outputs = self.vit(x, self.head_out_idx, n_dim_output=self.n_dim_output, return_features=True)
        if self.n_dim_output == 3:
            # Early return avoids the pointless unpack of vit_outputs below,
            # whose values would be discarded in this branch.
            feat0, feat1, feat2, feat3 = self.feature_transformer(vit_outputs)
            return feat0, feat1, feat2, feat3
        feat0, feat1, feat2, feat3 = vit_outputs[0], vit_outputs[1], vit_outputs[2], vit_outputs[3]
        return feat0, feat1, feat2, feat3
class GeneralEmbedModel(nn.Module):
    """Frozen timm backbone returning 2x-upsampled multi-scale features.

    Wraps a timm ``features_only`` model whose parameters are frozen
    (eval mode, ``requires_grad = False``); each of the four feature maps
    it produces is upsampled by a factor of 2 before being returned.
    """

    # Short alias -> timm checkpoint name. Adding a new backbone only
    # requires a new entry here instead of another copy-pasted branch.
    _TIMM_CHECKPOINTS = {
        "swin-tiny": "swinv2_cr_tiny_ns_224.sw_in1k",
        "swin-small": "swinv2_cr_small_ns_224.sw_in1k",
    }

    def __init__(self, pretrained_model: str = "swin-tiny", device: str = "cuda") -> None:
        """
        Args:
            pretrained_model: One of ``"swin-tiny"`` or ``"swin-small"``
                (e.g. checkpoints ``vit_tiny_patch16_224.augreg_in21k_ft_in1k``,
                ``swinv2_cr_tiny_ns_224.sw_in1k``).
            device: Device the backbone is moved to.

        Raises:
            NotImplementedError: If ``pretrained_model`` is not a known alias.
        """
        super().__init__()
        self.device = device
        self.pretrained_model = pretrained_model
        if pretrained_model not in self._TIMM_CHECKPOINTS:
            # Same exception type as before, now with a diagnostic message.
            raise NotImplementedError(f"Unsupported pretrained_model: {pretrained_model!r}")
        self.pretrained = create_model(
            self._TIMM_CHECKPOINTS[pretrained_model],
            pretrained=True,
            features_only=True,
            out_indices=[-4, -3, -2, -1],  # last four stages
        ).to(device)
        # Frozen backbone: eval mode + no gradients anywhere.
        self.pretrained.eval()
        self.upsample = nn.Upsample(scale_factor=2)
        for param in self.pretrained.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the list of backbone feature maps, each upsampled 2x."""
        features = self.pretrained(x)
        return [self.upsample(feat) for feat in features]