|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import timm |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from lavila.models.openai_clip import load as load_openai_clip |
|
from lavila.models.openai_model import QuickGELU, Transformer |
|
from lavila.models.timesformer import SpaceTimeTransformer |
|
from lavila.models.utils import remap_keys, rsetattr |
|
from lavila.models.prompt_tuning import PromptLearner |
|
|
|
|
|
class CLIP(nn.Module):
    """CLIP-style dual encoder (visual tower + text transformer) with prompt tuning.

    The visual tower is supplied by the caller; the text tower is built here from
    the CLIP text-transformer hyper-parameters. ``cfg`` is a dict-like config read
    with ``.get``; the keys used are:

    * ``tune_bias`` (default False) - when the text backbone is frozen, keep its
      bias terms trainable; also keeps frozen backbones in train mode in ``train``.
    * ``freeze_vis_backbone`` / ``freeze_txt_backbone`` (default False) - freeze
      the respective tower; parameters/modules whose name contains ``'prompt'``
      stay trainable.
    * ``t_step`` - stored as ``self.t_step``; defaults to ``vision_model.num_frames``.
    * ``text_prompt`` (default ``{}``) - sub-config forwarded to the text
      ``Transformer``; its ``n_ctx`` (default 0) and ``use_bank`` (default False)
      keys are also read here.
    """

    def __init__(self,
                 cfg,
                 embed_dim: int,
                 vision_width: int,
                 vision_model: nn.Module,
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 tempearture_init=0.07,
                 **kwargs,
                 ):
        """Build the text tower, projections and temperature, then apply freezing.

        Args:
            cfg: dict-like model config (see class docstring).
            embed_dim: dimension of the joint embedding space.
            vision_width: feature width produced by ``vision_model``.
            vision_model: visual encoder. It must expose ``num_frames``; depending on
                ``cfg`` it must also expose ``param_list`` (frozen when
                ``freeze_vis_backbone``), ``prompt_generator`` (when ``use_bank``) and
                ``prompt_dim`` (when ``use_bank`` and ``n_ctx > 0``).
            context_length: size of the causal attention mask and positional embedding.
            vocab_size: number of rows in the token-embedding table.
            transformer_width: text transformer width.
            transformer_heads: text transformer attention heads.
            transformer_layers: text transformer depth.
            tempearture_init: initial temperature. ``logit_scale`` starts at
                ``log(1 / tempearture_init)``. The misspelt name is kept because
                existing callers pass it by keyword.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.context_length = context_length
        self.vision_width = vision_width
        self.tune_bias = cfg.get('tune_bias', False)
        self.freeze_vis_backbone = cfg.get('freeze_vis_backbone', False)
        self.freeze_txt_backbone = cfg.get('freeze_txt_backbone', False)

        self.visual = vision_model
        self.t_step = cfg.get('t_step', self.visual.num_frames)
        txt_prompt_cfg = cfg.get('text_prompt', {})
        self.n_ctx = txt_prompt_cfg.get('n_ctx', 0)
        self.txt_use_bank = txt_prompt_cfg.get('use_bank', False)

        # One argument set for both variants; only the "bank" variant additionally
        # shares the visual tower's prompt generator with the text transformer.
        transformer_kwargs = dict(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            prompt_cfg=txt_prompt_cfg,
            prompt_learner=PromptLearner(transformer_width, self.n_ctx),
        )
        if self.txt_use_bank:
            transformer_kwargs['prompt_generator'] = self.visual.prompt_generator
        self.transformer = Transformer(**transformer_kwargs)

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = nn.LayerNorm(transformer_width)

        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        print("=> initialize initial temperature with {}".format(tempearture_init))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))

        self.initialize_parameters()

        # --- Freezing -------------------------------------------------------
        freeze_list = []
        if self.freeze_vis_backbone:
            print("=> Freeze visual backbone")
            freeze_list += self.visual.param_list + [self.image_projection]

        if self.freeze_txt_backbone:
            print("=> Freeze text backbone")
            if self.tune_bias:
                # Bias tuning: everything except prompt params and biases is frozen.
                freeze_list += [p for n, p in self.transformer.named_parameters()
                                if 'prompt' not in n and 'bias' not in n]
                freeze_list += [p for n, p in self.ln_final.named_parameters() if 'bias' not in n]
            else:
                freeze_list += [p for n, p in self.transformer.named_parameters() if 'prompt' not in n]
                freeze_list += list(self.ln_final.parameters())
            freeze_list += list(self.token_embedding.parameters())
            freeze_list += [self.positional_embedding, self.text_projection]

        for p in freeze_list:
            p.requires_grad = False

        # Prompt projections are created after freezing, so they always train.
        if self.n_ctx > 0 and self.txt_use_bank:
            prompt_dim = self.visual.prompt_dim
            if prompt_dim != transformer_width:
                self.transformer.prompt_inproj = nn.Linear(transformer_width, prompt_dim, bias=False)
            else:
                self.transformer.prompt_inproj = nn.Identity()
            self.transformer.prompt_outproj = nn.Linear(prompt_dim, transformer_width, bias=False)
            nn.init.kaiming_normal_(
                self.transformer.prompt_outproj.weight, a=0, mode='fan_out')

        params_to_update = [n for n, p in self.named_parameters() if p.requires_grad]
        num_opt_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        num_fz_params = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        print("=> Params to update: {}".format(params_to_update))
        print("=> Update/Frozen: {}/{}".format(num_opt_params, num_fz_params))

    def initialize_parameters(self):
        """Initialise embeddings, text-transformer weights and both projections.

        Uses normal initialisations whose std depends on ``self.transformer.width``,
        ``self.transformer.layers`` and ``self.vision_width``.
        """
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` additive causal mask.

        Entries strictly above the diagonal are ``-inf``; the diagonal and
        everything below it are 0.
        """
        return torch.full((self.context_length, self.context_length), float("-inf")).triu_(1)

    def encode_image(self, image, use_checkpoint=False, apply_project=True, istrain=False, gamma=1.0):
        """Run the visual tower and optionally project into the joint space.

        Returns:
            ``(features, ps_loss)``. ``ps_loss`` is the second output of the visual
            tower. If the tower returns its features as a list, that list must
            hold exactly one element.
        """
        x, ps_loss = self.visual(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)

        if isinstance(x, list):
            assert len(x) == 1
            x = x[0]
        if apply_project:
            x = x @ self.image_projection

        return x, ps_loss

    def encode_text(self, text, use_checkpoint=False, istrain=False, gamma=1.0):
        """Encode token ids and project the end-of-text feature into the joint space.

        Args:
            text: integer token ids. The permutes below require a 2-D input,
                presumably ``(batch, seq)``.

        Returns:
            ``(features, ps_loss)``. ``ps_loss`` is the second output of the text
            transformer.
        """
        x = self.token_embedding(text)
        # NOTE(review): assumes the end-of-text token has the largest id in each
        # sequence, so argmax locates it - confirm against the tokenizer.
        eot = text.argmax(dim=-1)

        x = x.permute(1, 0, 2)  # the transformer takes the sequence axis first
        x, ps_loss = self.transformer(x, self.positional_embedding, use_checkpoint=use_checkpoint,
                                      istrain=istrain, gamma=gamma, eot=eot)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x)

        # The EOT position is shifted by n_ctx, presumably because the transformer
        # inserts n_ctx prompt tokens before it - confirm in Transformer.
        x = x[torch.arange(x.shape[0]), self.n_ctx + eot] @ self.text_projection

        return x, ps_loss

    def forward(self, image, text, use_checkpoint=False, norm_embed=False, istrain=False, gamma=1.0):
        """Encode both modalities.

        Returns:
            A dict with the following keys:

            * ``image_embed`` and ``text_embed`` - L2-normalised along the last
              dim when ``norm_embed`` is True.
            * ``logit_scale`` - already exponentiated.
            * ``ps_loss`` - the sum of the two encoders' auxiliary outputs.
        """
        image_embed, ps_loss_img = self.encode_image(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        text_embed, ps_loss_txt = self.encode_text(text, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)

        if norm_embed:
            image_embed = F.normalize(image_embed, dim=-1)
            text_embed = F.normalize(text_embed, dim=-1)
        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp(),
                'ps_loss': ps_loss_img + ps_loss_txt}

    def train(self, mode=True):
        """Set training mode, keeping frozen (non-prompt) backbone modules in eval mode.

        Like ``nn.Module.train``, this returns ``self``, so ``model.train()`` and
        ``model.eval()`` can be chained or assigned.

        When ``mode`` is True and a backbone is frozen without bias tuning, every
        module of that backbone whose qualified name does not contain ``'prompt'``
        is put back into eval mode.

        Raises:
            ValueError: if ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        # Set the flags directly (rather than via super().train) so that
        # sub-module train() overrides are not invoked.
        for m in self.modules():
            m.training = mode

        if mode:
            if self.freeze_vis_backbone and not self.tune_bias:
                for n, m in self.visual.named_modules():
                    if 'prompt' not in n:
                        m.training = False

            if self.freeze_txt_backbone and not self.tune_bias:
                for n, m in self.transformer.named_modules():
                    if 'prompt' not in n:
                        m.training = False
                self.token_embedding.training = False
                self.ln_final.training = False

        return self
|
|
|
|
|
def CLIP_OPENAI_TIMESFORMER_BASE(
        num_frames=4, timesformer_gated_xattn=False, temperature_init=0.07,
        project_embed_dim=256, **kwargs
):
    """Build a ``CLIP`` whose towers are initialised from OpenAI CLIP ViT-B/16.

    The visual tower is a ``SpaceTimeTransformer`` loaded non-strictly from the
    remapped CLIP visual weights. The text transformer, token embedding,
    positional embedding and final LayerNorm are copied from CLIP's text tower.
    When ``project_embed_dim`` equals the width of CLIP's ``text_projection``,
    both projections and ``logit_scale`` are copied as well; otherwise they keep
    ``CLIP``'s own initialisation.

    Args:
        num_frames: frames per clip for the visual tower.
        timesformer_gated_xattn: forwarded as ``is_tanh_gating``.
        temperature_init: initial temperature passed to ``CLIP``.
        project_embed_dim: joint embedding dimension.
        **kwargs: ``model_cfg`` (popped; dict-like config shared by the visual
            tower and ``CLIP``); everything else is forwarded to ``CLIP``.
    """
    model_cfg = kwargs.pop('model_cfg', {})

    visual = SpaceTimeTransformer(
        num_frames=num_frames,
        time_init='zeros',
        attention_style='frozen-in-time',
        ln_pre=True,
        act_layer=QuickGELU,
        is_tanh_gating=timesformer_gated_xattn,
        drop_path_rate=model_cfg.get('drop_path_rate', 0),
        tune_bias=model_cfg.get('tune_bias', False),
        prompt_cfg=model_cfg.get('visual_prompt', {}),
    )

    openai_clip, _ = load_openai_clip('ViT-B/16', 'cpu')
    print("=> Loading CLIP (ViT-B/16) weights")
    visual_state = remap_keys(openai_clip.visual.state_dict(), transformer_layers=12)
    load_result = visual.load_state_dict(visual_state, strict=False)
    print(load_result)  # missing / unexpected keys from the non-strict load

    # Replace the output head layers with identities (a fresh module for each).
    for head_name in ('head', 'pre_logits', 'fc'):
        setattr(visual, head_name, nn.Identity())

    model = CLIP(
        model_cfg,
        embed_dim=project_embed_dim,
        vision_width=768,
        vision_model=visual,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        tempearture_init=temperature_init,
        **kwargs,
    )

    # Text tower. The transformer is loaded non-strictly, presumably because the
    # prompt-enabled Transformer has extra (prompt) parameters.
    model.transformer.load_state_dict(openai_clip.transformer.state_dict(), strict=False)
    model.token_embedding.load_state_dict(openai_clip.token_embedding.state_dict())
    model.positional_embedding.data.copy_(openai_clip.positional_embedding.data)
    model.ln_final.load_state_dict(openai_clip.ln_final.state_dict())

    projections_match = project_embed_dim == openai_clip.text_projection.shape[1]
    if projections_match:
        print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
        model.image_projection.data.copy_(openai_clip.visual.proj.data)
        model.text_projection.data.copy_(openai_clip.text_projection.data)
        model.logit_scale.data.copy_(openai_clip.logit_scale.data)

    return model
|
|
|
|
|
|