from torch import nn
from typing import List
from src.models.vit.factory import create_vit
from src.models.vit.vit import FeatureTransform
from ...utils import print_num_params
from timm import create_model
from einops import rearrange


class EmbedModel(nn.Module):
    def __init__(self, config, head_out_idx: List[int], n_dim_output=3, device="cuda") -> None:
        super().__init__()
        self.head_out_idx = head_out_idx
        self.n_dim_output = n_dim_output
        self.device = device

        # ViT backbone, kept frozen and in eval mode: it is used purely as a feature extractor.
        self.vit = create_vit(config).to(self.device)
        self.vit.eval()
        for params in self.vit.parameters():
            params.requires_grad = False
        print_num_params(self.vit)
        print_num_params(self.vit, is_trainable=True)

        # Extra transform applied to the ViT features when a 3-D output is requested.
        if self.n_dim_output == 3:
            self.feature_transformer = FeatureTransform(config["image_size"], config["d_model"]).to(self.device)
            print_num_params(self.feature_transformer)
            print_num_params(self.feature_transformer, is_trainable=True)

    def forward(self, x):
        vit_outputs = self.vit(x, self.head_out_idx, n_dim_output=self.n_dim_output, return_features=True)
        feat0, feat1, feat2, feat3 = vit_outputs[0], vit_outputs[1], vit_outputs[2], vit_outputs[3]
        if self.n_dim_output == 3:
            feat0, feat1, feat2, feat3 = self.feature_transformer(vit_outputs)
        return feat0, feat1, feat2, feat3
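

# Hedged usage sketch for EmbedModel (illustrative only): the exact config schema
# comes from src.models.vit.factory.create_vit, which is not shown in this file.
# Only "image_size" and "d_model" are read here; the config values and
# head_out_idx below are hypothetical.
#
#   config = {"image_size": 224, "d_model": 768, ...}
#   embedder = EmbedModel(config, head_out_idx=[2, 5, 8, 11], n_dim_output=3)
#   feat0, feat1, feat2, feat3 = embedder(torch.randn(1, 3, 224, 224, device="cuda"))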


class GeneralEmbedModel(nn.Module):
    def __init__(self, pretrained_model="swin-tiny", device="cuda") -> None:
        """Frozen timm backbone that returns multi-scale feature maps.

        Reference timm checkpoints:
            vit_tiny_patch16_224.augreg_in21k_ft_in1k
            swinv2_cr_tiny_ns_224.sw_in1k
        """
        super().__init__()
        self.device = device
        self.pretrained_model = pretrained_model
        if pretrained_model == "swin-tiny":
            self.pretrained = create_model(
                "swinv2_cr_tiny_ns_224.sw_in1k",
                pretrained=True,
                features_only=True,
                out_indices=[-4, -3, -2, -1],
            ).to(device)
        elif pretrained_model == "swin-small":
            self.pretrained = create_model(
                "swinv2_cr_small_ns_224.sw_in1k",
                pretrained=True,
                features_only=True,
                out_indices=[-4, -3, -2, -1],
            ).to(device)
        else:
            raise NotImplementedError(f"Unsupported pretrained_model: {pretrained_model}")
        self.pretrained.eval()
        self.upsample = nn.Upsample(scale_factor=2)

        # Freeze the backbone so it acts as a fixed feature extractor.
        for params in self.pretrained.parameters():
            params.requires_grad = False

    def forward(self, x):
        # Feature maps from the last four stages, each upsampled by a factor of 2.
        outputs = self.pretrained(x)
        outputs = [self.upsample(feat) for feat in outputs]
        return outputs
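

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original pipeline:
    # builds the swin-tiny variant on CPU and prints the shapes of the four
    # upsampled feature maps for a dummy 224x224 batch. Downloading the
    # pretrained weights via timm requires network access.
    import torch

    model = GeneralEmbedModel(pretrained_model="swin-tiny", device="cpu")
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feats = model(dummy)
    for i, feat in enumerate(feats):
        print(f"stage {i}: {tuple(feat.shape)}")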