import math
from functools import reduce
from operator import mul

from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F


class PromptLearner(nn.Module):
    """Learns a fixed-length sequence of continuous context (prompt) vectors."""

    def __init__(self, ctx_dim=512, n_ctx=16):
        super().__init__()
        self.n_ctx = n_ctx
        self.ctx_dim = ctx_dim

        # Randomly initialized context vectors, reported as an "X X ... X" placeholder prompt.
        ctx_vectors = torch.empty(n_ctx, ctx_dim)
        nn.init.normal_(ctx_vectors, std=0.02)
        prompt_prefix = " ".join(["X"] * n_ctx)
        self.ctx = nn.Parameter(ctx_vectors)
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

    def forward(self):
        return self.ctx
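

# Illustrative usage sketch (not part of the original module): a minimal check
# that PromptLearner exposes its learned context vectors. The function name and
# sizes are ours; shapes follow the constructor arguments.
def _demo_prompt_learner():
    learner = PromptLearner(ctx_dim=512, n_ctx=16)
    ctx = learner()
    assert ctx.shape == (16, 512)  # (n_ctx, ctx_dim)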


class PromptPoolLearner(nn.Module):
    """A pool (bank) of learnable prompts selected per sample by query/key similarity."""

    def __init__(self, prompt_dim=256, size=128, length=1):
        super().__init__()
        self.prompt_dim = prompt_dim
        self.length = length
        self.size = size

        # Pool of `size` prompts, each `length` tokens of dimension `prompt_dim`.
        self.prompt_values = nn.Parameter(torch.zeros(size, length, prompt_dim))
        # Running selection counts; registered as a buffer so it follows the
        # module across devices instead of being pinned to one device.
        self.register_buffer('id_table', torch.ones(size))

        nn.init.uniform_(self.prompt_values.data, -1, 1)

    def l2_normalize(self, x, dim=None, epsilon=1e-12):
        """Normalizes a given vector or matrix."""
        square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
        x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
        return x * x_inv_norm

    def forward(self, query, k=0, istrain=False, gamma=1.0):
        BZ = query.shape[0]
        out = dict()
        # Cosine similarity between the normalized query and the pool keys,
        # where each key is the mean of a prompt's tokens.
        query = self.l2_normalize(query.squeeze(1), dim=1)
        keys = self.prompt_values.mean(dim=1)
        keys = self.l2_normalize(keys, dim=1)
        similarity = torch.matmul(query, keys.t())

        if 0 < k < self.size:
            if istrain:
                # Sample k prompts, blending similarity with an inverse-frequency
                # term that encourages rarely selected prompts to be picked.
                inv_freq = self.id_table.sum() / self.id_table.float()
                weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
                idx = torch.multinomial(weights, k, replacement=False)
            else:
                idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]

            # Update selection counts for the inverse-frequency weighting.
            prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
            self.id_table[prompt_id] += id_counts
            prompts = self.prompt_values[idx.flatten(), ...].view(BZ, k * self.length, self.prompt_dim)
        else:
            # k out of range: every sample receives the whole pool.
            idx = torch.arange(self.size).unsqueeze(0).expand(BZ, -1)
            prompts = self.prompt_values.flatten(0, 1).unsqueeze(0).expand(BZ, -1, -1)

        prompts = self.l2_normalize(prompts, dim=-1)
        out['prompts'] = prompts
        # Prompt-selection loss: pull the similarity-weighted selected keys toward
        # the query, and keep the key set close to orthonormal.
        sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
        sel_key = keys[idx.flatten(), ...].view(BZ, idx.size(1), self.prompt_dim)
        diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
        ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size, device=keys.device))) / BZ
        out['ps_loss'] = diff + ksim

        return out
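

# Illustrative usage sketch (not part of the original module), assuming a CPU
# query of shape (batch, prompt_dim); the k and gamma values are arbitrary.
def _demo_prompt_pool():
    pool = PromptPoolLearner(prompt_dim=256, size=128, length=1)
    query = torch.randn(4, 256)
    out = pool(query, k=8, istrain=True, gamma=0.5)
    assert out['prompts'].shape == (4, 8, 256)  # (batch, k * length, prompt_dim)
    assert out['ps_loss'].dim() == 0  # scalar selection loss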


class VisualPromptLearner(nn.Module):
    """Visual prompts for a ViT backbone, optionally drawn from a prompt pool
    (feature bank) and optionally split into spatial and temporal halves."""

    def __init__(self, patch_size=16, embed_dim=768, num_layers=12, prompt_dim=256, num_tokens=5, deep=False,
                 deep_shared=False, split_st=False, dropout=0.1, pool={}):
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.prompt_dim = prompt_dim
        self.num_tokens = num_tokens
        self.prompt_dropout = nn.Dropout(dropout)
        pool = dict(pool)  # copy so halving 'size' below does not mutate the caller's config
        pool_size = pool.get('size', 0)
        self.pool_length = pool.get('length', 1)
        self.use_bank = pool_size > 0 and num_tokens <= pool_size * self.pool_length
        if self.use_bank:
            print(f'Using feature bank with size {pool_size} (dimension: {prompt_dim})')

        # Project backbone queries into the prompt space, and prompts back out.
        if prompt_dim != embed_dim:
            self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
        else:
            self.prompt_inproj = nn.Identity()

        if self.use_bank:
            self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
            nn.init.kaiming_normal_(
                self.prompt_outproj.weight, a=0, mode='fan_out')
        else:
            self.prompt_outproj = nn.Identity()

        self.split_st = split_st

        # Xavier-style initialization range used by VPT for prompt embeddings.
        val = math.sqrt(6. / float(3 * reduce(mul, (patch_size, patch_size), 1) + prompt_dim))
        if split_st:
            # Separate spatial and temporal prompts, each with half the budget.
            if self.use_bank:
                pool['size'] //= 2
                self.spatial_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
                self.temporal_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
            else:
                self.spatial_prompt_embeddings = nn.Parameter(torch.zeros(
                    1, num_tokens // 2, prompt_dim))
                self.temporal_prompt_embeddings = nn.Parameter(torch.zeros(
                    1, num_tokens // 2, prompt_dim))

                nn.init.uniform_(self.spatial_prompt_embeddings.data, -val, val)
                nn.init.uniform_(self.temporal_prompt_embeddings.data, -val, val)
        else:
            if self.use_bank:
                self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
            else:
                self.prompt_embeddings = nn.Parameter(torch.zeros(
                    1, num_tokens, prompt_dim))

                nn.init.uniform_(self.prompt_embeddings.data, -val, val)

        self.deep = deep or deep_shared
        self.deep_shared = deep_shared
        if deep and (not deep_shared):
            # One extra set of prompts for every transformer layer after the first.
            total_d_layer = num_layers - 1
            if split_st:
                if self.use_bank:
                    self.spatial_deep_prompt_pool = nn.ModuleList([
                        PromptPoolLearner(prompt_dim, **pool)
                        for i in range(total_d_layer)])
                    self.temporal_deep_prompt_pool = nn.ModuleList([
                        PromptPoolLearner(prompt_dim, **pool)
                        for i in range(total_d_layer)])
                else:
                    self.spatial_deep_prompt_embeddings = nn.Parameter(torch.zeros(
                        total_d_layer, num_tokens // 2, prompt_dim))
                    self.temporal_deep_prompt_embeddings = nn.Parameter(torch.zeros(
                        total_d_layer, num_tokens // 2, prompt_dim))

                    nn.init.uniform_(self.spatial_deep_prompt_embeddings.data, -val, val)
                    nn.init.uniform_(self.temporal_deep_prompt_embeddings.data, -val, val)
            else:
                if self.use_bank:
                    self.deep_prompt_pool = nn.ModuleList([
                        PromptPoolLearner(prompt_dim, **pool)
                        for i in range(total_d_layer)])
                else:
                    self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
                        total_d_layer, num_tokens, prompt_dim))

                    nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

    def forward(self, query=None, layer=0, istrain=False, gamma=1.0):
        query = query.detach()
        query = self.prompt_inproj(query)
        ps_loss = query.new_zeros([1])
        if self.split_st:
            if self.deep and (not self.deep_shared) and layer > 0:
                if self.use_bank:
                    # Half the token budget per branch, `pool_length` tokens per prompt.
                    k = (self.num_tokens // 2) // self.pool_length
                    spatial_out = self.spatial_deep_prompt_pool[layer - 1](query, k, istrain, gamma)
                    spatial_prompts = spatial_out['prompts']
                    temporal_out = self.temporal_deep_prompt_pool[layer - 1](query, k, istrain, gamma)
                    temporal_prompts = temporal_out['prompts']
                    ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
                else:
                    spatial_prompts = self.spatial_deep_prompt_embeddings[layer - 1]
                    temporal_prompts = self.temporal_deep_prompt_embeddings[layer - 1]
            else:
                if self.use_bank:
                    k = (self.num_tokens // 2) // self.pool_length
                    spatial_out = self.spatial_prompt_pool(query, k, istrain, gamma)
                    spatial_prompts = spatial_out['prompts']
                    temporal_out = self.temporal_prompt_pool(query, k, istrain, gamma)
                    temporal_prompts = temporal_out['prompts']
                    ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
                else:
                    spatial_prompts = self.spatial_prompt_embeddings
                    temporal_prompts = self.temporal_prompt_embeddings

            prompts = torch.cat((spatial_prompts, temporal_prompts), dim=1)

        else:
            if self.deep and (not self.deep_shared) and layer > 0:
                if self.use_bank:
                    k = self.num_tokens // self.pool_length
                    out = self.deep_prompt_pool[layer - 1](query, k, istrain, gamma)
                    prompts = out['prompts']
                    ps_loss += out.get('ps_loss', 0)
                else:
                    prompts = self.deep_prompt_embeddings[layer - 1]
            else:
                if self.use_bank:
                    k = self.num_tokens // self.pool_length
                    out = self.prompt_pool(query, k, istrain, gamma)
                    prompts = out['prompts']
                    ps_loss += out.get('ps_loss', 0)
                else:
                    prompts = self.prompt_embeddings

        # Map prompts back into the backbone embedding space.
        prompts = self.prompt_dropout(self.prompt_outproj(prompts))
        return prompts, ps_loss
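

# Illustrative usage sketch (not part of the original module), assuming a
# per-sample query of shape (batch, embed_dim); the pool config is arbitrary.
def _demo_visual_prompt_learner():
    vpl = VisualPromptLearner(embed_dim=768, prompt_dim=256, num_tokens=4,
                              pool={'size': 16, 'length': 1})
    query = torch.randn(2, 768)
    prompts, ps_loss = vpl(query, layer=0, istrain=False)
    assert prompts.shape == (2, 4, 768)  # (batch, num_tokens, embed_dim)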


class CMM(nn.Module):
    """Context modeling module: produces per-frame prompt tokens, either from a
    prompt pool (feature bank) or from a bidirectional LSTM over frame features."""

    def __init__(self, num_tokens=8, num_frames=16, embed_dim=768, prompt_dim=256, dropout=0., num_layer=1, shared=False, pool={}):
        super().__init__()
        self.num_tokens = num_tokens
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.prompt_dim = prompt_dim
        self.pool_size = pool.get('size', 0)
        self.pool_length = pool.get('length', 1)
        self.use_bank = self.pool_size > 0
        self.use_rnn = not self.use_bank
        if self.use_rnn:
            # Note: nn.LSTM ignores `dropout` when num_layers == 1.
            self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim,
                               num_layers=1, batch_first=True, dropout=dropout, bidirectional=True)
        self.shared = shared
        self.prompt_dropout = nn.Dropout(dropout)

        if self.use_bank:
            print(f'Using feature bank with size {self.pool_size} (dimension: {prompt_dim})')
            if self.use_rnn:
                # With use_rnn = not use_bank, this branch is currently unreachable.
                self.prompt_inproj = nn.Linear(embed_dim * 2, prompt_dim)
                nn.init.kaiming_normal_(
                    self.prompt_inproj.weight, a=0, mode='fan_out')
            else:
                if embed_dim != prompt_dim:
                    self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
                else:
                    self.prompt_inproj = nn.Identity()

            self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
            nn.init.kaiming_normal_(
                self.prompt_outproj.weight, a=0, mode='fan_out')

            if shared:
                # A single pool shared across all layers.
                self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
            else:
                self.prompt_pool = nn.ModuleList([
                    PromptPoolLearner(prompt_dim, **pool)
                    for i in range(num_layer)])
        else:
            self.fc = nn.Linear(embed_dim * 2, embed_dim * num_tokens)

    def forward(self, x, layer=0, istrain=False, gamma=1.0):
        x = x.detach()
        # Average patch tokens within each frame to get one feature per frame.
        x = rearrange(x, 'b (f n) d -> b f n d', f=self.num_frames)
        x = torch.mean(x, dim=2)

        if self.use_rnn:
            x, _ = self.rnn(x)

        ps_loss = x.new_zeros([1])
        if self.use_bank:
            # One query per frame; select k prompts from the pool for each.
            query = self.prompt_inproj(x).flatten(0, 1)
            k = self.num_tokens // self.pool_length
            if self.shared:
                out = self.prompt_pool(query, k, istrain, gamma)
            else:
                out = self.prompt_pool[layer](query, k, istrain, gamma)

            prompts = rearrange(out['prompts'], '(b f) p d -> b (f p) d', f=self.num_frames)
            prompts = self.prompt_outproj(prompts)
            ps_loss += out.get('ps_loss', 0) * self.num_frames

        else:
            # Predict num_tokens prompts per frame directly from the LSTM output.
            prompts = self.fc(x)
            prompts = rearrange(prompts, 'b f (p d) -> b (f p) d', p=self.num_tokens)

        return prompts, ps_loss
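

# Illustrative usage sketch (not part of the original module): feed CMM a batch
# of frame-wise patch tokens shaped (batch, num_frames * num_patches, embed_dim).
# The demo sizes are arbitrary.
def _demo_cmm():
    cmm = CMM(num_tokens=8, num_frames=4, embed_dim=768, prompt_dim=256,
              pool={'size': 16, 'length': 1})
    x = torch.randn(2, 4 * 3, 768)  # 2 clips, 4 frames x 3 patch tokens each
    prompts, ps_loss = cmm(x, layer=0)
    assert prompts.shape == (2, 4 * 8, 768)  # (batch, num_frames * num_tokens, embed_dim)


if __name__ == '__main__':
    # Smoke tests for the usage sketches above.
    _demo_prompt_learner()
    _demo_prompt_pool()
    _demo_visual_prompt_learner()
    _demo_cmm()
    print('All demo checks passed.')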