|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Implementations of Video Transformers in PyTorch |
|
A PyTorch implementation of space-time transformer as described in |
|
'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650 |
|
A PyTorch implementation of timesformer as described in |
|
'Is Space-Time Attention All You Need for Video Understanding?' - https://arxiv.org/abs/2102.05095 |
|
Acknowledgments: |
|
- This code builds on Ross Wightman's vision_transformer code in pytorch-image-models: |
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py |
|
- It is also inspired by lucidrains' TimeSformer implementation:
|
https://github.com/lucidrains/TimeSformer-pytorch |
|
Hacked together by Max Bain |
|
""" |
|
|
|
from collections import OrderedDict, defaultdict |
|
from functools import partial, reduce |
|
import operator |
|
import copy |
|
|
|
import torch |
|
import torch.utils.checkpoint as checkpoint |
|
from einops import rearrange, repeat |
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ |
|
from torch import einsum, nn |
|
import torch.nn.functional as F |
|
|
|
|
from lavila.models.prompt_tuning import VisualPromptLearner, CMM |
|
|
|
|
|
def attn(q, k, v):
    """Plain dot-product attention; the caller is expected to have scaled q already."""
|
sim = einsum('b i d, b j d -> b i j', q, k) |
|
attn = sim.softmax(dim=-1) |
|
out = einsum('b i j, b j d -> b i d', attn, v) |
|
return out |
|
|
|
|
|
class Mlp(nn.Module): |
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
|
super().__init__() |
|
out_features = out_features or in_features |
|
hidden_features = hidden_features or in_features |
|
self.fc1 = nn.Linear(in_features, hidden_features) |
|
self.act = act_layer() |
|
self.fc2 = nn.Linear(hidden_features, out_features) |
|
self.drop = nn.Dropout(drop) |
|
|
|
def forward(self, x): |
|
x = self.fc1(x) |
|
x = self.act(x) |
|
x = self.drop(x) |
|
x = self.fc2(x) |
|
x = self.drop(x) |
|
return x |
|
|
|
|
|
class VideoPatchEmbed(nn.Module): |
|
""" Video to Patch Embedding |
|
""" |
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, |
|
num_frames=8, ln_pre=False): |
|
super().__init__() |
|
img_size = to_2tuple(img_size) |
|
patch_size = to_2tuple(patch_size) |
|
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.num_patches = num_patches |
|
self.num_frames = num_frames |
|
self.embed_dim = embed_dim |
|
|
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=not ln_pre) |
|
|
|
def forward(self, x): |
|
B, F, C, H, W = x.shape |
|
assert F <= self.num_frames |
|
x = x.view(-1, C, H, W) |
|
x = self.proj(x) |
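        # -> (batch * frames, embed_dim, height // patch_size, width // patch_size)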
|
return x |
|
|
|
|
|
class VarAttention(nn.Module): |
|
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., |
|
initialize='random', num_tokens=0): |
|
super().__init__() |
|
self.num_heads = num_heads |
|
head_dim = dim // num_heads |
|
|
|
self.scale = qk_scale or head_dim ** -0.5 |
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|
self.proj = nn.Linear(dim, dim) |
|
if initialize == 'zeros': |
|
self.qkv.weight.data.fill_(0) |
|
self.qkv.bias.data.fill_(0) |
|
|
|
|
|
self.proj.weight.data.fill_(1) |
|
self.proj.bias.data.fill_(0) |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
self.num_tokens = num_tokens |
|
|
|
def forward(self, x, einops_from, einops_to, einops_dims, cfg): |
|
style = cfg.get('style', 'default') |
|
pt_att = cfg.get('pt_att', True) |
|
n_seg = cfg.get('n_seg', 4) |
|
if 'VoP' in style: |
|
return self.forward_VoP(x, einops_from, einops_to, einops_dims, n_seg) |
|
elif style == 'attall': |
|
return self.forward_attall(x, pt_att) |
|
else: |
|
return self.forward_features(x, einops_from, einops_to, einops_dims, pt_att) |
|
|
|
def forward_features(self, x, einops_from, einops_to, einops_dims, pt_att=True): |
|
h = self.num_heads |
|
num_tokens = self.num_tokens |
|
if self.num_tokens > 0 and not pt_att: |
|
prompts = x[:, 1:self.num_tokens+1, :] |
|
x = torch.cat(( |
|
x[:, :1, :], |
|
x[:, self.num_tokens+1:, :] |
|
), dim=1) |
|
num_tokens = 0 |
|
|
|
|
|
q, k, v = self.qkv(x).chunk(3, dim=-1) |
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) |
|
|
|
q *= self.scale |
|
|
|
|
|
        # the first (1 + num_tokens) tokens are the CLS token plus any visual prompts;
        # they attend over the full sequence, while patch tokens attend along one axis only
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v))
|
|
|
|
|
cls_out = attn(cls_q, k, v) |
|
|
|
q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_)) |
|
|
|
|
|
        # broadcast the CLS/prompt keys and values so that every space/time group
        # can also attend to them
        r = q_.shape[0] // cls_k.shape[0]
|
cls_k, cls_v = map(lambda t: repeat(t, 'b p d -> (b r) p d', r=r), (cls_k, cls_v)) |
|
k_ = torch.cat((cls_k, k_), dim=1) |
|
v_ = torch.cat((cls_v, v_), dim=1) |
|
|
|
|
|
out = attn(q_, k_, v_) |
|
|
|
|
|
out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) |
|
|
|
|
|
out = torch.cat((cls_out, out), dim=1) |
|
|
|
|
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) |
|
if self.num_tokens > 0 and not pt_att: |
|
out = torch.cat(( |
|
out[:, :1, :], |
|
prompts, |
|
out[:, 1:, :] |
|
), dim=1) |
|
|
|
|
|
x = self.proj(out) |
|
x = self.proj_drop(x) |
|
return x |
|
|
|
def forward_VoP(self, x, einops_from, einops_to, einops_dims, n_seg=4): |
|
|
|
h = self.num_heads |
|
num_tokens = self.num_tokens |
|
|
|
|
|
q, k, v = self.qkv(x).chunk(3, dim=-1) |
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) |
|
|
|
q *= self.scale |
|
|
|
|
|
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v)) |
|
|
|
cls_out = attn(cls_q[:, :1, :], k, v) |
|
|
|
|
|
        # VoP-style attention: split the prompt tokens into n_seg groups and pair each
        # group with a contiguous temporal segment of frames
        pstep = num_tokens // n_seg
|
pseg = [range(st, en) for st, en in zip(range(1, num_tokens+1, pstep), range(pstep+1, num_tokens+2, pstep))] |
|
p_q, p_k, p_v = map(lambda t: rearrange(t[:, pseg, :], 'b s p d -> (b s) p d'), (cls_q, cls_k, cls_v)) |
|
|
|
|
|
q_, k_, v_ = map(lambda t: rearrange(t, 'b (f n) d -> b f n d', **einops_dims), (q_, k_, v_)) |
|
num_frames = k_.size(1) |
|
tstep = num_frames // n_seg |
|
tseg = [range(st, en) for st, en in zip(range(0, num_frames, tstep), range(tstep, num_frames+1, tstep))] |
|
q_, k_, v_ = map(lambda t: t[:, tseg, ...], (q_, k_, v_)) |
|
q_, k_, v_ = map(lambda t: rearrange(t, 'b s f n d -> (b s) (f n) d'), (q_, k_, v_)) |
|
|
|
|
|
k_, v_ = map(lambda t: torch.cat((t[0], t[1]), dim=1), ((p_k, k_), (p_v, v_))) |
|
p_out = attn(p_q, k_, v_) |
|
out = attn(q_, k_, v_) |
|
p_out = rearrange(p_out, '(b s) p d -> b (s p) d', s=n_seg) |
|
out = rearrange(out, '(b s) (f n) d -> b (s f n) d', s=n_seg, f=tstep) |
|
|
|
|
|
out = torch.cat((cls_out, p_out, out), dim=1) |
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) |
|
|
|
|
|
x = self.proj(out) |
|
x = self.proj_drop(x) |
|
return x |
|
|
|
def forward_attall(self, x, pt_att=True): |
|
h = self.num_heads |
|
if self.num_tokens > 0 and not pt_att: |
|
prompts = x[:, 1:self.num_tokens+1, :] |
|
x = torch.cat(( |
|
x[:, :1, :], |
|
x[:, self.num_tokens+1:, :] |
|
), dim=1) |
|
|
|
|
|
q, k, v = self.qkv(x).chunk(3, dim=-1) |
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) |
|
|
|
q *= self.scale |
|
|
|
|
|
out = attn(q, k, v) |
|
|
|
|
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) |
|
if self.num_tokens > 0 and not pt_att: |
|
out = torch.cat(( |
|
out[:, :1, :], |
|
prompts, |
|
out[:, 1:, :] |
|
), dim=1) |
|
|
|
|
|
x = self.proj(out) |
|
x = self.proj_drop(x) |
|
return x |
|
|
|
|
|
class SpaceTimeBlock(nn.Module): |
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., |
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, time_init='zeros', |
|
attention_style='frozen-in-time', is_tanh_gating=False, num_tokens=0, split_st=False): |
|
super().__init__() |
|
|
|
self.split_st = split_st |
|
if split_st: |
|
num_tokens = num_tokens // 2 |
|
self.num_tokens = num_tokens |
|
|
|
self.norm1 = norm_layer(dim) |
|
self.attn = VarAttention( |
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens) |
|
|
|
self.timeattn = VarAttention( |
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens, |
|
initialize=time_init) |
|
|
|
if is_tanh_gating: |
|
self.alpha_timeattn = nn.Parameter(torch.zeros([])) |
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
self.norm2 = norm_layer(dim) |
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) |
|
self.norm3 = norm_layer(dim) |
|
|
|
self.attention_style = attention_style |
|
|
|
def forward(self, x, einops_from_space, einops_to_space, einops_from_time, einops_to_time, |
|
time_n, space_f, use_checkpoint=False, pt_spt=True, pt_tmp=True, style='default', n_seg=4): |
|
        # with split_st, the prompt tokens are split in half: the first half is used only
        # for spatial attention and the second half only for temporal attention, so each
        # half is swapped out around the attention step it does not belong to
        if self.split_st:
|
spatial_prompts = x[:, 1:self.num_tokens+1, :] |
|
x = torch.cat(( |
|
x[:, :1, :], |
|
x[:, self.num_tokens+1:, :] |
|
), dim=1) |
|
|
|
if use_checkpoint: |
|
time_output = checkpoint.checkpoint( |
|
self.timeattn, self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp} |
|
) |
|
else: |
|
time_output = self.timeattn(self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp}) |
|
if hasattr(self, "alpha_timeattn"): |
|
time_output = torch.tanh(self.alpha_timeattn) * time_output |
|
time_residual = x + time_output |
|
|
|
if self.split_st: |
|
temporal_prompts = time_residual[:, 1:self.num_tokens+1, :] |
|
time_residual = torch.cat(( |
|
time_residual[:, :1, :], |
|
spatial_prompts, |
|
time_residual[:, self.num_tokens+1:, :] |
|
), dim=1) |
|
|
|
cfg = {'style': style, 'pt_att': pt_spt, 'n_seg': n_seg} |
|
if use_checkpoint: |
|
space_output = checkpoint.checkpoint( |
|
self.attn, self.norm1(time_residual), einops_from_space, einops_to_space, {"f": space_f}, cfg |
|
) |
|
else: |
|
space_output = self.attn(self.norm1(time_residual), einops_from_space, |
|
einops_to_space, {"f": space_f}, cfg) |
|
if self.attention_style == 'frozen-in-time': |
|
space_residual = x + self.drop_path(space_output) |
|
else: |
|
raise NotImplementedError |
|
|
|
if self.split_st: |
|
space_residual = torch.cat(( |
|
space_residual[:, :self.num_tokens+1, :], |
|
temporal_prompts, |
|
space_residual[:, self.num_tokens+1:, :] |
|
), dim=1) |
|
|
|
x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual))) |
|
|
|
return x |
|
|
|
|
|
class SpaceTimeTransformer(nn.Module): |
|
""" Vision Transformer |
|
    A PyTorch impl of: `Space-Time Transformer` from Frozen-in-time - by Max Bain.
        https://arxiv.org/abs/2104.00650
    Based off:
    - ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py]
    - lucidrains' TimeSformer implementation [https://github.com/lucidrains/TimeSformer-pytorch].
|
Notable differences: |
|
- allows for variable length input frames (<= num_frames) |
|
- allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED] |
|
- different attention block mechanism |
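
    Example (illustrative only; default configuration, prompt tuning disabled):
        model = SpaceTimeTransformer()                  # num_frames defaults to 8
        clip = torch.zeros(2, 4, 3, 224, 224)           # (batch, frames <= num_frames, channels, H, W)
        feats, ps_loss = model.forward_features(clip)   # feats: (2, 768)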
|
""" |
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, |
|
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, |
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, |
|
num_frames=8, time_init='rand', attention_style='frozen-in-time', ln_pre=False, |
|
act_layer=nn.GELU, is_tanh_gating=False, tune_bias=False, prompt_cfg={}): |
|
""" |
|
Args: |
|
img_size (int, tuple): input image size |
|
patch_size (int, tuple): patch size |
|
in_chans (int): number of input channels |
|
num_classes (int): number of classes for classification head |
|
embed_dim (int): embedding dimension |
|
depth (int): depth of transformer |
|
num_heads (int): number of attention heads |
|
mlp_ratio (int): ratio of mlp hidden dim to embedding dim |
|
qkv_bias (bool): enable bias for qkv if True |
|
qk_scale (float): override default qk scale of head_dim ** -0.5 if set |
|
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set |
|
drop_rate (float): dropout rate |
|
attn_drop_rate (float): attention dropout rate |
|
drop_path_rate (float): stochastic depth rate |
|
hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module |
|
norm_layer: (nn.Module): normalization layer |
|
num_frames: (int) maximum number of frames expected as input |
|
time_init: (str) how to initialise the time attention layer, 'zeros' allows for the timesformer to start off |
|
as ViT. |
|
            attention_style: (str) how to attend to space and time.
            ln_pre: (bool) if True, apply a LayerNorm to the embedded tokens before the
                transformer blocks (the patch projection then uses no bias).
            act_layer: (nn.Module) activation layer used in the MLP blocks.
            is_tanh_gating: (bool) if True, scale the temporal attention output by a
                learnable tanh gate (initialised at zero).
            tune_bias: (bool) if True, keep the backbone bias terms trainable while the
                rest of the backbone is frozen.
            prompt_cfg: (dict) visual prompt tuning configuration; an empty dict (default)
                disables prompting.
        """
|
        super().__init__()
        # work on a copy so that the pops below do not mutate the caller's dict
        # (prompt_cfg={} is a shared mutable default argument)
        prompt_cfg = dict(prompt_cfg)
|
self.num_classes = num_classes |
|
self.num_features = self.embed_dim = embed_dim |
|
self.num_frames = num_frames |
|
self.embed_dim = embed_dim |
|
self.tune_bias = tune_bias |
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
|
print("######USING ATTENTION STYLE: ", attention_style) |
|
self.param_list = [] |
|
if hybrid_backbone is not None: |
|
raise NotImplementedError('hybrid backbone not implemented') |
|
else: |
|
self.patch_embed = VideoPatchEmbed( |
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=num_frames, ln_pre=ln_pre) |
|
self.param_list += list(self.patch_embed.parameters()) |
|
num_patches = self.patch_embed.num_patches |
|
self.patches_per_frame = num_patches // num_frames |
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|
self.pos_embed = nn.Parameter( |
|
torch.zeros(1, self.patches_per_frame + 1, |
|
embed_dim)) |
|
self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) |
|
self.param_list += [self.cls_token, self.pos_embed, self.temporal_embed] |
|
|
|
if ln_pre: |
|
self.ln_pre = nn.LayerNorm(embed_dim) |
|
if self.tune_bias: |
|
self.param_list += [m for n, m in self.ln_pre.named_parameters() if 'bias' not in n] |
|
else: |
|
self.param_list += list(self.ln_pre.parameters()) |
|
else: |
|
self.ln_pre = None |
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate) |
|
|
|
|
|
self.num_tokens = prompt_cfg.get('num_tokens', 0) |
|
self.prompt_dim = prompt_cfg.get('prompt_dim', 768) |
|
self.pt_spt = prompt_cfg.pop('pt_spt', True) |
|
self.pt_tmp = prompt_cfg.pop('pt_tmp', True) |
|
self.style = prompt_cfg.pop('style', 'default') |
|
self.query = prompt_cfg.pop('query', 'cls') |
|
self.n_seg = prompt_cfg.pop('n_seg', 4) |
|
self.k_s = prompt_cfg.pop('K_s', depth) |
|
self.st = prompt_cfg.pop('st', 0) |
|
self.end = prompt_cfg.pop('end', depth) |
|
assert self.st <= self.end |
|
if self.style == 'default': |
|
print(f'Prompting {self.st}-{self.end} layer of the visual backbone') |
|
elif self.style == 'VoP_c' and self.k_s < depth: |
|
self.prompt_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) |
|
elif self.style == 'VoP_c_pool': |
|
self.prompt_temp_embed = nn.Parameter(torch.zeros(1, self.n_seg, embed_dim)) |
|
trunc_normal_(self.prompt_temp_embed, std=.02) |
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] |
|
|
|
blocks = [] |
|
for i in range(depth): |
|
stblk_cfg = {} |
|
if self.num_tokens > 0: |
|
stblk_cfg = {'num_tokens': prompt_cfg['num_tokens'], 'split_st': prompt_cfg.get('split_st', False)} |
|
blocks.append( |
|
SpaceTimeBlock( |
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, |
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, time_init=time_init, |
|
attention_style=attention_style, act_layer=act_layer, is_tanh_gating=is_tanh_gating, **stblk_cfg) |
|
) |
|
|
|
self.blocks = nn.ModuleList(blocks) |
|
self.norm = norm_layer(embed_dim) |
|
if self.tune_bias: |
|
self.param_list += reduce(operator.add, [[m for n, m in x.named_parameters() if 'bias' not in n] for x in self.blocks]) |
|
self.param_list += [m for n, m in self.norm.named_parameters() if 'bias' not in n] |
|
else: |
|
self.param_list += reduce(operator.add, [list(x.parameters()) for x in self.blocks]) |
|
self.param_list += list(self.norm.parameters()) |
|
|
|
|
|
if representation_size: |
|
self.num_features = representation_size |
|
self.pre_logits = nn.Sequential(OrderedDict([ |
|
('fc', nn.Linear(embed_dim, representation_size)), |
|
('act', nn.Tanh()) |
|
])) |
|
if self.tune_bias: |
|
self.param_list += [m for n, m in self.pre_logits.named_parameters() if 'bias' not in n] |
|
else: |
|
self.param_list += list(self.pre_logits.parameters()) |
|
else: |
|
self.pre_logits = nn.Identity() |
|
|
|
|
|
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() |
|
|
|
trunc_normal_(self.pos_embed, std=.02) |
|
trunc_normal_(self.cls_token, std=.02) |
|
|
|
|
|
if num_frames == 1: |
|
self.apply(self._init_weights) |
|
|
|
|
|
        # divided space-time attention patterns: for spatial attention the patches of
        # each frame attend to each other ('(b f) n d'); for temporal attention each
        # spatial location attends across frames ('(b n) f d')
        self.einops_from_space = 'b (f n) d'
        self.einops_to_space = '(b f) n d'
        self.einops_from_time = 'b (f n) d'
        self.einops_to_time = '(b n) f d'
|
|
|
|
|
self.prompt_learner = None |
|
if self.num_tokens > 0: |
|
if 'VoP_c' in self.style: |
|
pool = prompt_cfg.pop('pool', {}) if 'pool' in self.style else {} |
|
if self.k_s > 0: |
|
self.prompt_generator = CMM(self.num_tokens // self.n_seg, self.n_seg, embed_dim, self.prompt_dim, num_layer=self.k_s, \ |
|
shared=prompt_cfg.get('deep_shared', False), pool=pool) |
|
n_prompt_layer = depth - self.k_s |
|
|
|
else: |
|
n_prompt_layer = self.end - self.st |
|
|
|
if n_prompt_layer > 0: |
|
prompt_cfg['num_layers'] = n_prompt_layer |
|
prompt_cfg['prompt_dim'] = embed_dim |
|
self.prompt_learner = VisualPromptLearner(patch_size, embed_dim, **prompt_cfg) |
|
|
|
        # freeze the pre-trained backbone parameters collected above; prompt modules
        # (and biases, when tune_bias=True) remain trainable
        for p in self.param_list:
            p.requires_grad = False
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, nn.LayerNorm): |
|
nn.init.constant_(m.bias, 0) |
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
@torch.jit.ignore |
|
def no_weight_decay(self): |
|
return {'pos_embed', 'cls_token'} |
|
|
|
def get_classifier(self): |
|
return self.head |
|
|
|
def reset_classifier(self, num_classes, global_pool=''): |
|
self.num_classes = num_classes |
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() |
|
|
|
def forward_features(self, x, use_checkpoint=False, cls_at_last=True, istrain=False, gamma=1.0): |
|
|
|
b, curr_frames, channels, _, _ = x.shape |
|
x = self.patch_embed(x) |
|
x = x.flatten(2).transpose(2, 1) |
|
x = x.reshape(b, -1, self.patch_embed.embed_dim) |
|
|
|
BF = x.shape[0] |
|
cls_tokens = self.cls_token.expand(BF, -1, -1) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
|
        # positional encoding: one CLS positional embedding, the per-frame spatial
        # embedding tiled across frames, plus the temporal embedding repeated for
        # every patch within a frame
        cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
|
tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1) |
|
|
|
tile_temporal_embed = self.temporal_embed.repeat_interleave(self.patches_per_frame, 1) |
|
total_pos_embed = tile_pos_embed + tile_temporal_embed |
|
total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) |
|
|
|
curr_patches = x.shape[1] |
|
x = x + total_pos_embed[:, :curr_patches] |
|
ps_loss = x.new_zeros([1]) |
|
|
|
        # prepend learnable visual prompt tokens right after the CLS token: 'VoP_c'
        # generates them from the patch features via the CMM prompt generator,
        # otherwise the prompt learner conditions on the CLS token
        if self.num_tokens > 0:
|
if 'VoP_c' in self.style and self.k_s > 0: |
|
ctx, ps = self.prompt_generator(x[:, 1:, :], 0, istrain=istrain, gamma=gamma) |
|
ps_loss += ps |
|
if self.prompt_generator.use_bank: |
|
prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1) |
|
ctx = ctx + prompt_temp_embed |
|
|
|
elif self.prompt_learner is not None: |
|
ctx, ps = self.prompt_learner(x[:, :1, :], 0, istrain=istrain, gamma=gamma) |
|
ps_loss += ps |
|
if ctx.size(0) != BF: |
|
ctx = ctx.expand(BF, -1, -1) |
|
|
|
x = torch.cat(( |
|
x[:, :1, :], |
|
ctx, |
|
x[:, 1:, :] |
|
), dim=1) |
|
|
|
if self.ln_pre is not None: |
|
x = self.ln_pre(x) |
|
x = self.pos_drop(x) |
|
n = self.patches_per_frame |
|
f = curr_frames |
|
|
|
for i, blk in enumerate(self.blocks): |
|
if self.num_tokens > 0 and i > 0 and i >= self.st and i < self.end: |
|
if 'VoP_c' in self.style: |
|
if i < self.k_s: |
|
ctx, ps = self.prompt_generator(x[:, self.num_tokens+1:, :], i, istrain=istrain, gamma=gamma) |
|
ps_loss += ps |
|
if self.prompt_generator.use_bank: |
|
prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1) |
|
ctx = ctx + prompt_temp_embed |
|
else: |
|
ctx, ps = self.prompt_learner(x[:, :1, :], i-self.k_s, istrain=istrain, gamma=gamma) |
|
ps_loss += ps |
|
|
|
if 'pool' in self.style: |
|
prompt_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1) |
|
else: |
|
prompt_embed = self.prompt_embed.repeat_interleave(self.num_tokens // self.num_frames, 1) |
|
ctx = ctx + prompt_embed |
|
if ctx.size(0) != BF: |
|
ctx = ctx.expand(BF, -1, -1) |
|
|
|
elif (i - self.st) < self.prompt_learner.num_layers: |
|
ctx, ps = self.prompt_learner(x[:, :1, :], i-self.st, istrain=istrain, gamma=gamma) |
|
ps_loss += ps |
|
if ctx.size(0) != BF: |
|
ctx = ctx.expand(BF, -1, -1) |
|
|
|
x = torch.cat(( |
|
x[:, :1, :], |
|
ctx, |
|
x[:, self.num_tokens+1:, :] |
|
), dim=1) |
|
|
|
style = 'default' if i >= self.k_s else self.style |
|
pt_tmp = self.pt_tmp if i >= self.st and i < self.end else False |
|
pt_spt = self.pt_spt if i >= self.st and i < self.end else False |
|
x = blk(x, self.einops_from_space, self.einops_to_space, self.einops_from_time, |
|
self.einops_to_time, |
|
time_n=n, space_f=f, use_checkpoint=use_checkpoint, pt_spt=pt_spt, |
|
pt_tmp=pt_tmp, style=style, n_seg=self.n_seg) |
|
|
|
if cls_at_last: |
|
x = self.norm(x) |
|
x = x[:, 0] |
|
x = self.pre_logits(x) |
|
|
|
return x, ps_loss |
|
else: |
|
return self.norm(x), ps_loss |
|
|
|
def forward(self, x, use_checkpoint=False, istrain=False, gamma=1.0): |
|
|
|
|
|
        # (B, C, T, H, W) -> (B, T, C, H, W)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
|
x, ps_loss = self.forward_features(x, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma) |
|
x = self.head(x) |
|
|
|
return x, ps_loss |
|
|
|
def train(self, mode=True): |
|
if not isinstance(mode, bool): |
|
raise ValueError("training mode is expected to be boolean") |
|
self.training = mode |
|
for m in self.modules(): |
|
m.training = mode |
|
|
|
        # when prompt tuning, keep everything except the prompt modules in eval mode
        # so that only the prompt parameters see train-time behaviour (e.g. dropout)
        if mode and self.num_tokens > 0:
|
for n, m in self.named_modules(): |
|
if 'prompt' not in n: |
|
m.training = False |
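

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). The sizes below are arbitrary small
# values for a quick smoke test, not the configuration used in the papers
# referenced above, and prompt tuning is left disabled (prompt_cfg={}).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = SpaceTimeTransformer(
        img_size=112, patch_size=16, embed_dim=192, depth=2, num_heads=3,
        num_frames=4, num_classes=0, time_init='zeros')
    model.eval()
    # input is (batch, channels, frames, height, width); clips with fewer frames
    # than num_frames are also accepted
    clip = torch.randn(2, 3, 4, 112, 112)
    with torch.no_grad():
        feats, ps_loss = model(clip)
    print(feats.shape, ps_loss.item())  # -> torch.Size([2, 192]) 0.0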
|
|
|
|