# Prompt-learning modules: context prompts, prompt pools, visual prompts, CMM.
import math | |
from functools import reduce | |
from operator import mul | |
from einops import rearrange, repeat | |
import pdb | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class PromptLearner(nn.Module):
    """Holds a bank of learnable continuous context ("prompt") vectors.

    Args:
        ctx_dim: dimensionality of each context vector.
        n_ctx: number of context vectors (prompt tokens).
    """

    def __init__(self, ctx_dim=512, n_ctx=16):
        super().__init__()
        self.n_ctx = n_ctx
        self.ctx_dim = ctx_dim
        # Context vectors start from a small random normal (std 0.02).
        initial_ctx = torch.empty(n_ctx, ctx_dim)
        nn.init.normal_(initial_ctx, std=0.02)
        prompt_prefix = " ".join(["X"] * n_ctx)
        self.ctx = nn.Parameter(initial_ctx)  # optimized during training
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

    def forward(self):
        """Return the learnable context tensor of shape (n_ctx, ctx_dim)."""
        return self.ctx
class PromptPoolLearner(nn.Module):
    """A pool of learnable prompts selected per sample by query/key similarity.

    Each of the `size` entries is a (length, prompt_dim) prompt. A query picks
    the k most similar entries (or samples k, frequency-rebalanced, at train
    time); the selected prompts are returned with a pull/separation loss.

    Args:
        prompt_dim: dimensionality of each prompt token.
        size: number of entries in the pool.
        length: number of tokens per pool entry.
    """

    def __init__(self, prompt_dim=256, size=128, length=1):
        super(PromptPoolLearner, self).__init__()
        self.prompt_dim = prompt_dim
        self.length = length
        self.size = size
        # Learnable pool values, uniform-initialized in [-1, 1].
        # (The original comment said "xavier_uniform", but this is a plain
        # uniform init; kept as-is to preserve behavior.)
        self.prompt_values = nn.Parameter(torch.zeros(size, length, prompt_dim))
        nn.init.uniform_(self.prompt_values.data, -1, 1)
        # Selection-frequency table used to rebalance sampling at train time.
        # Registered as a NON-persistent buffer so it follows .to()/.cuda()
        # with the module instead of the original unconditional `.cuda()`
        # (which crashed on CPU-only machines), while keeping the state_dict
        # layout unchanged.
        self.register_buffer('id_table', torch.ones(size), persistent=False)

    def l2_normalize(self, x, dim=None, epsilon=1e-12):
        """Normalizes a given vector or matrix."""
        square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
        # rsqrt with an epsilon floor avoids division by zero for zero rows.
        x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
        return x * x_inv_norm

    def forward(self, query, k=0, istrain=False, gamma=1.0):
        """Select prompts for each query.

        Args:
            query: (BZ, 1, prompt_dim) or (BZ, prompt_dim) query features.
            k: number of entries to select; k <= 0 or k >= size returns the
                whole pool.
            istrain: if True, sample entries weighted by similarity and
                inverse selection frequency instead of deterministic top-k.
            gamma: trade-off between similarity and inverse-frequency weights.

        Returns:
            dict with 'prompts' of shape (BZ, n_selected * length, prompt_dim)
            and 'ps_loss', a scalar pull/separation regularizer.
        """
        BZ = query.shape[0]
        out = dict()
        query = self.l2_normalize(query.squeeze(1), dim=1)
        # One key per pool entry: mean over the entry's tokens, normalized.
        keys = self.prompt_values.mean(dim=1)
        keys = self.l2_normalize(keys, dim=1)
        similarity = torch.matmul(query, keys.t())  # (BZ, size), in [-1, 1]
        if 0 < k < self.size:
            if istrain:
                # Rarely-selected prompts get boosted so the pool stays in use.
                inv_freq = self.id_table.sum() / self.id_table.float()
                weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
                idx = torch.multinomial(weights, k, replacement=False)
            else:
                idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
            prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
            self.id_table[prompt_id] += id_counts
            prompts = self.prompt_values[idx.flatten(), ...].view(BZ, k * self.length, self.prompt_dim)
        else:
            # Degenerate k: hand every sample the entire pool.
            idx = torch.arange(self.size, device=keys.device).unsqueeze(0).expand(BZ, -1)
            prompts = self.prompt_values.flatten(0, 1).unsqueeze(0).expand(BZ, -1, -1)
        prompts = self.l2_normalize(prompts, dim=-1)
        out['prompts'] = prompts
        sel_sim = similarity[torch.arange(BZ, device=similarity.device).view(-1, 1), idx]
        # view(BZ, -1, ...) instead of the original view(BZ, k, ...): in the
        # whole-pool branch k may be 0 or >= size, which made the view fail.
        sel_key = keys[idx.flatten(), ...].view(BZ, -1, self.prompt_dim)
        # Pull similarity-weighted selected keys toward the (detached) query...
        diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
        # ...and push keys toward mutual orthogonality.
        ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))) / BZ
        out['ps_loss'] = diff + ksim
        return out
class VisualPromptLearner(nn.Module):
    """Learnable prompt tokens for a vision transformer.

    Prompts are shallow (input layer only) or deep (a distinct set per layer),
    optionally split into spatial/temporal halves, and either stored as free
    parameters or drawn from a query-conditioned PromptPoolLearner bank.

    Args:
        patch_size: ViT patch size; only used to derive the init range.
        embed_dim: transformer token dimension (bank output dimension).
        num_layers: transformer depth; deep prompts add num_layers - 1 sets.
        prompt_dim: internal prompt dimension before projection to embed_dim.
        num_tokens: number of prompt tokens per layer.
        deep: use a distinct prompt set per layer after the first.
        deep_shared: reuse the first-layer prompts at every layer.
        split_st: split tokens into spatial and temporal halves.
        dropout: dropout applied to the emitted prompts.
        pool: bank config ({'size': ..., 'length': ...}); empty disables it.
    """

    def __init__(self, patch_size=16, embed_dim=768, num_layers=12, prompt_dim=256, num_tokens=5, deep=False,
                 deep_shared=False, split_st=False, dropout=0.1, pool={}):
        super(VisualPromptLearner, self).__init__()
        # Work on a private copy: the original code mutated the caller's dict
        # (and the shared `{}` default argument) in place via `pool['size'] //= 2`.
        pool = dict(pool)
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.prompt_dim = prompt_dim
        self.num_tokens = num_tokens  # number of prompted tokens
        self.prompt_dropout = nn.Dropout(dropout)
        pool_size = pool.get('size', 0)
        self.pool_length = pool.get('length', 1)
        # The bank is only usable if it can supply at least num_tokens tokens.
        self.use_bank = pool_size > 0 and num_tokens <= (pool_size * self.pool_length)
        if self.use_bank:
            print(f'Using feature bank with size {pool_size} (dimension: {prompt_dim})')
        if prompt_dim != embed_dim:
            self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
        else:
            self.prompt_inproj = nn.Identity()
        if self.use_bank:
            self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
            nn.init.kaiming_normal_(
                self.prompt_outproj.weight, a=0, mode='fan_out')
        else:
            # NOTE(review): without the bank the emitted prompts stay at
            # prompt_dim; presumably callers set prompt_dim == embed_dim in
            # that configuration — confirm against call sites.
            self.prompt_outproj = nn.Identity()
        self.split_st = split_st  # split spatial and temporal prompts
        # Xavier-style uniform bound from the patch-embedding fan-in.
        val = math.sqrt(6. / float(3 * reduce(mul, (patch_size, patch_size), 1) + prompt_dim))
        if split_st:
            # NOTE(review): odd num_tokens yields 2 * (num_tokens // 2) tokens,
            # silently dropping one — confirm this is intended.
            if self.use_bank:
                pool['size'] //= 2  # halve the pool for each of spatial/temporal
                self.spatial_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
                self.temporal_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
            else:
                self.spatial_prompt_embeddings = nn.Parameter(torch.zeros(
                    1, num_tokens // 2, prompt_dim))
                self.temporal_prompt_embeddings = nn.Parameter(torch.zeros(
                    1, num_tokens // 2, prompt_dim))
                # xavier_uniform initialization
                nn.init.uniform_(self.spatial_prompt_embeddings.data, -val, val)
                nn.init.uniform_(self.temporal_prompt_embeddings.data, -val, val)
        else:
            if self.use_bank:
                self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
            else:
                self.prompt_embeddings = nn.Parameter(torch.zeros(
                    1, num_tokens, prompt_dim))
                # xavier_uniform initialization
                nn.init.uniform_(self.prompt_embeddings.data, -val, val)
        self.deep = deep or deep_shared
        self.deep_shared = deep_shared
        if deep and (not deep_shared):
            # One extra prompt set for every layer after the input layer.
            total_d_layer = num_layers - 1
            if split_st:
                if self.use_bank:
                    # `pool` was already halved above for the split case.
                    self.spatial_deep_prompt_pool = nn.ModuleList([
                        PromptPoolLearner(prompt_dim, **pool)
                        for i in range(total_d_layer)])
                    self.temporal_deep_prompt_pool = nn.ModuleList([
                        PromptPoolLearner(prompt_dim, **pool)
                        for i in range(total_d_layer)])
                else:
                    self.spatial_deep_prompt_embeddings = nn.Parameter(torch.zeros(
                        total_d_layer, num_tokens // 2, prompt_dim))
                    self.temporal_deep_prompt_embeddings = nn.Parameter(torch.zeros(
                        total_d_layer, num_tokens // 2, prompt_dim))
                    # xavier_uniform initialization
                    nn.init.uniform_(self.spatial_deep_prompt_embeddings.data, -val, val)
                    nn.init.uniform_(self.temporal_deep_prompt_embeddings.data, -val, val)
            else:
                if self.use_bank:
                    self.deep_prompt_pool = nn.ModuleList([
                        PromptPoolLearner(prompt_dim, **pool)
                        for i in range(total_d_layer)])
                else:
                    self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
                        total_d_layer, num_tokens, prompt_dim))
                    # xavier_uniform initialization
                    nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

    def forward(self, query=None, layer=0, istrain=False, gamma=1.0):
        """Return (prompts, ps_loss) for the given layer.

        Args:
            query: feature tensor used to query the bank; detached here.
                Required (despite the None default) even without the bank.
            layer: transformer layer index; layer > 0 selects deep prompts.
            istrain: forwarded to the pool (stochastic selection).
            gamma: similarity/frequency trade-off forwarded to the pool.
        """
        query = query.detach()
        query = self.prompt_inproj(query)
        ps_loss = query.new_zeros([1])
        if self.split_st:
            if self.deep and (not self.deep_shared) and layer > 0:
                if self.use_bank:
                    # Each half supplies half the tokens.
                    k = (self.num_tokens // 2) // self.pool_length
                    spatial_out = self.spatial_deep_prompt_pool[layer-1](query, k, istrain, gamma)
                    spatial_prompts = spatial_out['prompts']
                    temporal_out = self.temporal_deep_prompt_pool[layer-1](query, k, istrain, gamma)
                    temporal_prompts = temporal_out['prompts']
                    ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
                else:
                    spatial_prompts = self.spatial_deep_prompt_embeddings[layer-1]
                    temporal_prompts = self.temporal_deep_prompt_embeddings[layer-1]
            else:
                if self.use_bank:
                    k = (self.num_tokens // 2) // self.pool_length
                    spatial_out = self.spatial_prompt_pool(query, k, istrain, gamma)
                    spatial_prompts = spatial_out['prompts']
                    temporal_out = self.temporal_prompt_pool(query, k, istrain, gamma)
                    temporal_prompts = temporal_out['prompts']
                    ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
                else:
                    spatial_prompts = self.spatial_prompt_embeddings
                    temporal_prompts = self.temporal_prompt_embeddings
            prompts = torch.cat((spatial_prompts, temporal_prompts), dim=1)
        else:
            if self.deep and (not self.deep_shared) and layer > 0:
                if self.use_bank:
                    k = self.num_tokens // self.pool_length
                    out = self.deep_prompt_pool[layer-1](query, k, istrain, gamma)
                    prompts = out['prompts']
                    ps_loss += out.get('ps_loss', 0)
                else:
                    prompts = self.deep_prompt_embeddings[layer-1]
            else:
                if self.use_bank:
                    k = self.num_tokens // self.pool_length
                    out = self.prompt_pool(query, k, istrain, gamma)
                    prompts = out['prompts']
                    ps_loss += out.get('ps_loss', 0)
                else:
                    prompts = self.prompt_embeddings
        prompts = self.prompt_dropout(self.prompt_outproj(prompts))
        return prompts, ps_loss
class CMM(nn.Module):
    '''Context modeling module.

    Emits per-frame prompt tokens from patch tokens. With a configured pool
    ({'size': > 0}) each frame's mean feature queries a PromptPoolLearner
    bank; otherwise a bidirectional LSTM summarizes the frame sequence and a
    linear head generates the prompts directly.
    '''

    def __init__(self, num_tokens=8, num_frames=16, embed_dim=768, prompt_dim=256, dropout=0., num_layer=1, shared=False, pool={}):
        super(CMM, self).__init__()
        self.num_tokens = num_tokens
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.prompt_dim = prompt_dim
        self.pool_size = pool.get('size', 0)
        self.pool_length = pool.get('length', 1)
        self.use_bank = self.pool_size > 0
        self.use_rnn = not self.use_bank
        if self.use_rnn:
            # num_layers == 1, so inter-layer dropout would be ignored anyway;
            # pass 0. explicitly to avoid PyTorch's warning (no behavior change).
            self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim,
                               num_layers=1, batch_first=True, dropout=0., bidirectional=True)
        self.shared = shared
        self.prompt_dropout = nn.Dropout(dropout)
        if self.use_bank:
            print(f'Using feature bank with size {self.pool_size} (dimension: {prompt_dim})')
            # NOTE: use_rnn is always `not use_bank`, so the original
            # rnn-specific inproj branch (Linear(embed_dim * 2, prompt_dim))
            # was unreachable here and has been removed.
            if embed_dim != prompt_dim:
                self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
            else:
                self.prompt_inproj = nn.Identity()
            self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
            nn.init.kaiming_normal_(
                self.prompt_outproj.weight, a=0, mode='fan_out')
            if shared:
                # One pool queried at every layer.
                self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
            else:
                # One pool per transformer layer.
                self.prompt_pool = nn.ModuleList([
                    PromptPoolLearner(prompt_dim, **pool)
                    for i in range(num_layer)])
        else:
            # Maps the biLSTM state to num_tokens prompts per frame.
            self.fc = nn.Linear(embed_dim * 2, embed_dim * num_tokens)

    def forward(self, x, layer=0, istrain=False, gamma=1.0):
        """Return (prompts, ps_loss) for patch tokens x of shape (b, f*n, d).

        Args:
            x: patch features; f = num_frames, n = patches per frame.
            layer: which per-layer pool to query (ignored when shared / no bank).
            istrain: forwarded to the pool (stochastic selection).
            gamma: similarity/frequency trade-off forwarded to the pool.
        """
        x = x.detach()
        bz = x.size(0)
        # (b, f*n, d) -> (b, f, n, d) -> mean over patches -> (b, f, d);
        # equivalent to einops rearrange('b (f n) d -> b f n d') + mean(dim=2).
        x = x.reshape(bz, self.num_frames, -1, x.size(-1)).mean(dim=2)
        if self.use_rnn:
            x, _ = self.rnn(x)
        ps_loss = x.new_zeros([1])
        if self.use_bank:
            query = self.prompt_inproj(x).flatten(0, 1)  # (b*f, prompt_dim)
            k = self.num_tokens // self.pool_length
            bank = self.prompt_pool if self.shared else self.prompt_pool[layer]
            out = bank(query, k, istrain, gamma)
            # '(b f) p d -> b (f p) d'
            prompts = out['prompts'].reshape(bz, -1, self.prompt_dim)
            prompts = self.prompt_outproj(prompts)
            ps_loss += out.get('ps_loss', 0) * self.num_frames
        else:
            # 'b f (p d) -> b (f p) d' with p = num_tokens
            prompts = self.fc(x).reshape(bz, self.num_frames * self.num_tokens, self.embed_dim)
        return prompts, ps_loss