File size: 10,994 Bytes
b3660df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lavila.models.openai_clip import load as load_openai_clip
from lavila.models.openai_model import QuickGELU, Transformer
from lavila.models.timesformer import SpaceTimeTransformer
from lavila.models.utils import remap_keys, rsetattr
from lavila.models.prompt_tuning import PromptLearner
class CLIP(nn.Module):
    """Video-text CLIP model with optional prompt tuning and backbone freezing.

    Pairs a caller-supplied vision encoder with a CLIP-style text Transformer.
    Both towers are linearly projected into a shared ``embed_dim`` space.

    Recognised ``cfg`` keys (all optional, read with ``cfg.get``):
        tune_bias: keep bias terms trainable when freezing the text backbone.
        freeze_vis_backbone: freeze ``vision_model.param_list`` and the image
            projection.
        freeze_txt_backbone: freeze the text tower (except 'prompt' params).
        t_step: stored as ``self.t_step``; defaults to
            ``vision_model.num_frames``.
        text_prompt: dict with ``n_ctx`` (number of text prompt tokens) and
            ``use_bank`` (share the visual prompt generator). It is also
            forwarded to the text Transformer as ``prompt_cfg``.
    """

    def __init__(self,
                 cfg,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 tempearture_init=0.07,
                 **kwargs,
                 ):
        """Build both towers, initialise parameters and apply freezing.

        Args:
            cfg: dict-like model config (see class docstring).
            embed_dim: width of the joint image/text embedding space.
            vision_width: output width of ``vision_model``.
            vision_model: visual encoder; must expose ``num_frames`` (unless
                ``t_step`` is given). It must also expose ``param_list`` when
                freezing, and ``prompt_generator``/``prompt_dim`` when
                ``text_prompt.use_bank`` is set.
            context_length: text sequence length (size of the causal mask
                and the positional embedding).
            vocab_size: number of text tokens.
            transformer_width / transformer_heads / transformer_layers:
                text Transformer hyper-parameters.
            tempearture_init: initial softmax temperature. ``logit_scale``
                starts at ``log(1 / tempearture_init)``. The misspelled name
                is kept for backward compatibility with existing callers.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.context_length = context_length
        self.vision_width = vision_width
        self.tune_bias = cfg.get('tune_bias', False)
        self.freeze_vis_backbone = cfg.get('freeze_vis_backbone', False)
        self.freeze_txt_backbone = cfg.get('freeze_txt_backbone', False)

        self.visual = vision_model
        self.t_step = cfg.get('t_step', self.visual.num_frames)

        txt_prompt_cfg = cfg.get('text_prompt', {})
        self.n_ctx = txt_prompt_cfg.get('n_ctx', 0)
        self.txt_use_bank = txt_prompt_cfg.get('use_bank', False)

        # Both text-Transformer variants share every argument. The
        # prompt-bank variant additionally reuses the visual prompt generator.
        transformer_kwargs = dict(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            prompt_cfg=txt_prompt_cfg,
            prompt_learner=PromptLearner(transformer_width, self.n_ctx),
        )
        if self.txt_use_bank:
            transformer_kwargs['prompt_generator'] = self.visual.prompt_generator
        self.transformer = Transformer(**transformer_kwargs)

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = nn.LayerNorm(transformer_width)  # used to be `models.transformer.LayerNorm``

        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        print("=> initialize initial temperature with {}".format(tempearture_init))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))

        self.initialize_parameters()

        # Freezing. On the visual side, the encoder decides what to freeze via
        # its `param_list`. On the text side, parameters whose name contains
        # 'prompt' are never frozen; with tune_bias, bias terms stay trainable.
        freeze_list = []
        if self.freeze_vis_backbone:
            print("=> Freeze visual backbone")
            freeze_list += self.visual.param_list + [self.image_projection]
        if self.freeze_txt_backbone:
            print("=> Freeze text backbone")
            if self.tune_bias:
                freeze_list += [m for n, m in self.transformer.named_parameters()
                                if 'prompt' not in n and 'bias' not in n]
                freeze_list += [m for n, m in self.ln_final.named_parameters() if 'bias' not in n]
            else:
                freeze_list += [m for n, m in self.transformer.named_parameters() if 'prompt' not in n]
                freeze_list += list(self.ln_final.parameters())
            freeze_list += list(self.token_embedding.parameters())
            freeze_list += [self.positional_embedding] + [self.text_projection]
        for p in freeze_list:
            p.requires_grad = False

        # Text prompt projections. They are created after freezing, so they
        # are always trainable.
        if self.n_ctx > 0:
            if self.txt_use_bank:
                prompt_dim = self.visual.prompt_dim
                if prompt_dim != transformer_width:
                    self.transformer.prompt_inproj = nn.Linear(transformer_width, prompt_dim, bias=False)
                else:
                    self.transformer.prompt_inproj = nn.Identity()
                self.transformer.prompt_outproj = nn.Linear(prompt_dim, transformer_width, bias=False)
                nn.init.kaiming_normal_(
                    self.transformer.prompt_outproj.weight, a=0, mode='fan_out')

        params_to_update = [n for n, m in self.named_parameters() if m.requires_grad]
        num_opt_params = sum(m.numel() for m in self.parameters() if m.requires_grad)
        num_fz_params = sum(m.numel() for m in self.parameters() if not m.requires_grad)
        print("=> Params to update: {}".format(params_to_update))
        print("=> Update/Frozen: {}/{}".format(num_opt_params, num_fz_params))

    def initialize_parameters(self):
        """Randomly initialise the text tower and both projection matrices.

        Uses normal init with CLIP-style standard deviations. The residual
        output projections are scaled down by ``(2 * layers) ** -0.5``.
        """
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` additive causal mask.

        Entries above the diagonal are ``-inf``. The diagonal and everything
        below it are ``0``, so each position attends only to itself and
        earlier positions.
        """
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    def encode_image(self, image, use_checkpoint=False, apply_project=True, istrain=False, gamma=1.0):
        """Encode ``image`` with the visual tower.

        Returns:
            ``(features, ps_loss)``. ``features`` are projected by
            ``image_projection`` when ``apply_project`` is true. ``ps_loss``
            is the auxiliary value returned by ``self.visual``.
        """
        x, ps_loss = self.visual(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        # Some visual encoders return a single-element list of features.
        if isinstance(x, list):
            assert len(x) == 1
            x = x[0]
        if apply_project:
            x = x @ self.image_projection
        return x, ps_loss

    def encode_text(self, text, use_checkpoint=False, istrain=False, gamma=1.0):
        """Encode token ids ``text`` (batch first) with the text tower.

        Returns:
            ``(features, ps_loss)``. ``features`` is the projected
            ``ln_final`` output at each sequence's EOT position.
        """
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        # eot_token is the highest token id in each sequence
        eot = text.argmax(dim=-1)
        x = x.permute(1, 0, 2)  # NLD -> LND
        # The positional embedding is handed to the transformer, not added here.
        x, ps_loss = self.transformer(x, self.positional_embedding, use_checkpoint=use_checkpoint,
                                      istrain=istrain, gamma=gamma, eot=eot)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, seq_len, transformer.width]
        # The EOT index is offset by n_ctx. Presumably the transformer
        # prepends n_ctx prompt tokens; TODO confirm against Transformer.
        x = x[torch.arange(x.shape[0]), self.n_ctx + eot] @ self.text_projection
        return x, ps_loss

    def forward(self, image, text, use_checkpoint=False, norm_embed=False, istrain=False, gamma=1.0):
        """Encode both modalities.

        Returns:
            dict with ``image_embed``, ``text_embed`` (L2-normalised when
            ``norm_embed``), ``logit_scale`` (already exponentiated) and
            ``ps_loss`` (the sum of both towers' auxiliary values).
        """
        image_embed, ps_loss_img = self.encode_image(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        text_embed, ps_loss_txt = self.encode_text(text, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)

        if norm_embed:
            image_embed = F.normalize(image_embed, dim=-1)
            text_embed = F.normalize(text_embed, dim=-1)
        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp(),
                'ps_loss': ps_loss_img + ps_loss_txt}

    def train(self, mode=True):
        """Set training mode, keeping frozen backbones in eval mode.

        Mirrors ``nn.Module.train``. When a backbone is frozen (and
        ``tune_bias`` is off), every module of that backbone whose name does
        not contain 'prompt' stays in eval mode even when ``mode`` is True.

        Returns:
            ``self``, like ``nn.Module.train``, so ``model.train()`` and
            ``model.eval()`` can be chained or reassigned.

        Raises:
            ValueError: if ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for m in self.modules():
            m.training = mode

        if mode:
            if self.freeze_vis_backbone and not self.tune_bias:
                for n, m in self.visual.named_modules():
                    if 'prompt' not in n:
                        m.training = False
            if self.freeze_txt_backbone and not self.tune_bias:
                for n, m in self.transformer.named_modules():
                    if 'prompt' not in n:
                        m.training = False
                self.token_embedding.training = False
                self.ln_final.training = False
        return self
def CLIP_OPENAI_TIMESFORMER_BASE(
        num_frames=4, timesformer_gated_xattn=False, temperature_init=0.07,
        project_embed_dim=256, **kwargs
):
    """Build a :class:`CLIP` video-text model seeded from OpenAI CLIP ViT-B/16.

    The visual tower is a ``SpaceTimeTransformer`` loaded non-strictly from
    the key-remapped CLIP image encoder. The text Transformer, token and
    positional embeddings and final LayerNorm are copied from the CLIP text
    encoder. When ``project_embed_dim`` equals CLIP's projection width, the
    image/text projections and ``logit_scale`` are copied as well.

    Args:
        num_frames: frame count the SpaceTimeTransformer is built for.
        timesformer_gated_xattn: forwarded as ``is_tanh_gating``.
        temperature_init: initial temperature handed to :class:`CLIP`. It is
            superseded by CLIP's ``logit_scale`` when projections are copied.
        project_embed_dim: width of the joint embedding space.
        **kwargs: ``model_cfg`` is popped (default ``{}``). It supplies
            ``drop_path_rate``, ``tune_bias`` and ``visual_prompt`` for the
            visual tower, and is passed to :class:`CLIP` as ``cfg``. All
            remaining kwargs are forwarded to :class:`CLIP`.

    Returns:
        The constructed :class:`CLIP` model.
    """
    model_cfg = kwargs.pop('model_cfg', {})

    visual_tower = SpaceTimeTransformer(
        num_frames=num_frames,
        time_init='zeros',
        attention_style='frozen-in-time',
        ln_pre=True,
        act_layer=QuickGELU,
        is_tanh_gating=timesformer_gated_xattn,
        drop_path_rate=model_cfg.get('drop_path_rate', 0),
        tune_bias=model_cfg.get('tune_bias', False),
        prompt_cfg=model_cfg.get('visual_prompt', {}),
    )

    openai_clip, _ = load_openai_clip('ViT-B/16', 'cpu')
    print("=> Loading CLIP (ViT-B/16) weights")
    visual_state = remap_keys(openai_clip.visual.state_dict(), transformer_layers=12)
    # Non-strict load; report missing/unexpected keys.
    print(visual_tower.load_state_dict(visual_state, strict=False))

    # Replace these submodules with identities, presumably so the tower
    # returns features rather than classifier outputs.
    for head_name in ('head', 'pre_logits', 'fc'):
        setattr(visual_tower, head_name, nn.Identity())

    model = CLIP(
        model_cfg,
        embed_dim=project_embed_dim,
        vision_width=768,
        vision_model=visual_tower,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        tempearture_init=temperature_init,
        **kwargs
    )

    # strict=False here presumably because the prompt-augmented Transformer
    # has keys the OpenAI checkpoint lacks. TODO confirm.
    model.transformer.load_state_dict(openai_clip.transformer.state_dict(), strict=False)
    model.token_embedding.load_state_dict(openai_clip.token_embedding.state_dict())
    model.positional_embedding.data.copy_(openai_clip.positional_embedding.data)
    model.ln_final.load_state_dict(openai_clip.ln_final.state_dict())

    if project_embed_dim != openai_clip.text_projection.shape[1]:
        return model

    print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
    for target, source in (
            (model.image_projection, openai_clip.visual.proj),
            (model.text_projection, openai_clip.text_projection),
            (model.logit_scale, openai_clip.logit_scale),
    ):
        target.data.copy_(source.data)
    return model
|