# EK100MIR/lavila/models/prompt_tuning.py
import math
from functools import reduce
from operator import mul

from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

class PromptLearner(nn.Module):
def __init__(self, ctx_dim=512, n_ctx=16):
super(PromptLearner, self).__init__()
self.n_ctx = n_ctx
self.ctx_dim = ctx_dim
# initialize prompts
ctx_vectors = torch.empty(n_ctx, ctx_dim)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
self.ctx = nn.Parameter(ctx_vectors) # to be optimized
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens): {n_ctx}")
def forward(self):
return self.ctx
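
# Illustrative usage sketch (added for exposition; not from the original repo).
# It shows the shape contract of PromptLearner: one learnable context matrix
# shared across the batch. The batch size of 8 is an arbitrary assumption.
def _demo_prompt_learner():
    learner = PromptLearner(ctx_dim=512, n_ctx=16)
    ctx = learner()                                 # (n_ctx, ctx_dim) = (16, 512)
    batch_ctx = ctx.unsqueeze(0).expand(8, -1, -1)  # broadcast to a batch of 8
    assert batch_ctx.shape == (8, 16, 512)
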
class PromptPoolLearner(nn.Module):
def __init__(self, prompt_dim=256, size=128, length=1):
super(PromptPoolLearner, self).__init__()
self.prompt_dim = prompt_dim
self.length = length
self.size = size
        # initialize the prompt pool
        self.prompt_values = nn.Parameter(torch.zeros(size, length, prompt_dim))
        # usage counts per pool entry; registered as a buffer so it follows
        # the module across devices instead of being pinned to CUDA
        self.register_buffer('id_table', torch.ones(size))
        # uniform initialization in [-1, 1]
        nn.init.uniform_(self.prompt_values.data, -1, 1)
def l2_normalize(self, x, dim=None, epsilon=1e-12):
"""Normalizes a given vector or matrix."""
square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
return x * x_inv_norm
    def forward(self, query, k=0, istrain=False, gamma=1.0):
        BZ = query.shape[0]
        out = dict()
        # cosine similarity between each query and the mean key of every pool entry
        query = self.l2_normalize(query.squeeze(1), dim=1)
        keys = self.prompt_values.mean(dim=1)
        keys = self.l2_normalize(keys, dim=1)
        similarity = torch.matmul(query, keys.t())
        if k > 0 and k < self.size:
            if istrain:
                # during training, blend cosine similarity with an
                # inverse-frequency term so rarely used pool entries
                # still get sampled
                inv_freq = self.id_table.sum() / self.id_table.float()
                weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
                idx = torch.multinomial(weights, k, replacement=False)
            else:
                # at inference, take the top-k most similar pool entries
                idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
            # update the usage counts of the selected entries
            prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
            self.id_table[prompt_id] += id_counts
            prompts = self.prompt_values[idx.flatten(), ...].view(BZ, k * self.length, self.prompt_dim)
        else:
            # k is out of range: return the whole pool for every sample
            idx = torch.arange(self.size).unsqueeze(0).expand(BZ, -1)
            prompts = self.prompt_values.flatten(0, 1).unsqueeze(0).expand(BZ, -1, -1)
        prompts = self.l2_normalize(prompts, dim=-1)
        out['prompts'] = prompts
        # prompt-selection loss: pull the similarity-weighted combination of
        # selected keys toward the query, and push distinct keys toward orthogonality
        sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
        sel_key = keys[idx.flatten(), ...].view(BZ, -1, self.prompt_dim)  # -1 covers both the top-k and full-pool branches
        diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
        ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))) / BZ
        out['ps_loss'] = diff + ksim
        return out
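
# Illustrative usage sketch (added for exposition; not from the original repo).
# Each query retrieves k entries from the pool, so the returned prompts have
# shape (batch, k * length, prompt_dim); batch size and k are assumptions.
def _demo_prompt_pool():
    pool = PromptPoolLearner(prompt_dim=256, size=128, length=1)
    query = torch.randn(8, 1, 256)        # (batch, 1, prompt_dim)
    out = pool(query, k=4, istrain=False)
    assert out['prompts'].shape == (8, 4, 256)
    # 'ps_loss' couples query/key alignment with a key-orthogonality penalty
    print(out['ps_loss'].item())
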
class VisualPromptLearner(nn.Module):
def __init__(self, patch_size=16, embed_dim=768, num_layers=12, prompt_dim=256, num_tokens=5, deep=False,
deep_shared=False, split_st=False, dropout=0.1, pool={}):
super(VisualPromptLearner, self).__init__()
self.num_layers = num_layers
self.embed_dim = embed_dim
self.prompt_dim = prompt_dim
self.num_tokens = num_tokens # number of prompted tokens
self.prompt_dropout = nn.Dropout(dropout)
        pool_size = pool.get('size', 0)
        self.pool_length = pool.get('length', 1)
        # use the prompt pool only if it can supply num_tokens prompts
        self.use_bank = pool_size > 0 and num_tokens <= pool_size * self.pool_length
if self.use_bank:
print(f'Using feature bank with size {pool_size} (dimension: {prompt_dim})')
if prompt_dim != embed_dim:
self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
else:
self.prompt_inproj = nn.Identity()
if self.use_bank:
self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
nn.init.kaiming_normal_(
self.prompt_outproj.weight, a=0, mode='fan_out')
else:
self.prompt_outproj = nn.Identity()
self.split_st = split_st # split spatial and temporal prompts
        # initialize prompts; the uniform bound is an xavier-style range over
        # the patch-embedding fan-in (3 * patch_size^2) plus prompt_dim
        val = math.sqrt(6. / float(3 * reduce(mul, (patch_size, patch_size), 1) + prompt_dim))
        if split_st:
            if self.use_bank:
                # halve the pool per stream; copy so the caller's dict is not mutated
                pool = dict(pool, size=pool['size'] // 2)
self.spatial_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
self.temporal_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
else:
self.spatial_prompt_embeddings = nn.Parameter(torch.zeros(
1, num_tokens // 2, prompt_dim))
self.temporal_prompt_embeddings = nn.Parameter(torch.zeros(
1, num_tokens // 2, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.spatial_prompt_embeddings.data, -val, val)
nn.init.uniform_(self.temporal_prompt_embeddings.data, -val, val)
else:
if self.use_bank:
self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
else:
self.prompt_embeddings = nn.Parameter(torch.zeros(
1, num_tokens, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.prompt_embeddings.data, -val, val)
self.deep = deep or deep_shared
self.deep_shared = deep_shared
if deep and (not deep_shared):
total_d_layer = num_layers - 1
if split_st:
if self.use_bank:
self.spatial_deep_prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(total_d_layer)])
self.temporal_deep_prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(total_d_layer)])
else:
self.spatial_deep_prompt_embeddings = nn.Parameter(torch.zeros(
total_d_layer, num_tokens // 2, prompt_dim))
self.temporal_deep_prompt_embeddings = nn.Parameter(torch.zeros(
total_d_layer, num_tokens // 2, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.spatial_deep_prompt_embeddings.data, -val, val)
nn.init.uniform_(self.temporal_deep_prompt_embeddings.data, -val, val)
else:
if self.use_bank:
self.deep_prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(total_d_layer)])
else:
self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
total_d_layer, num_tokens, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
    def forward(self, query=None, layer=0, istrain=False, gamma=1.0):
        if query is None:
            # a query is only needed for pool retrieval; plain prompt
            # parameters can be returned without one
            ps_loss = torch.zeros(1)
        else:
            query = query.detach()
            query = self.prompt_inproj(query)
            ps_loss = query.new_zeros([1])
        if self.split_st:
            # separate spatial and temporal prompt streams
            if self.deep and (not self.deep_shared) and layer > 0:
if self.use_bank:
k = (self.num_tokens // 2) // self.pool_length
spatial_out = self.spatial_deep_prompt_pool[layer-1](query, k, istrain, gamma)
spatial_prompts = spatial_out['prompts']
temporal_out = self.temporal_deep_prompt_pool[layer-1](query, k, istrain, gamma)
temporal_prompts = temporal_out['prompts']
ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
else:
spatial_prompts = self.spatial_deep_prompt_embeddings[layer-1]
temporal_prompts = self.temporal_deep_prompt_embeddings[layer-1]
else:
if self.use_bank:
k = (self.num_tokens // 2) // self.pool_length
spatial_out = self.spatial_prompt_pool(query, k, istrain, gamma)
spatial_prompts = spatial_out['prompts']
temporal_out = self.temporal_prompt_pool(query, k, istrain, gamma)
temporal_prompts = temporal_out['prompts']
ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
else:
spatial_prompts = self.spatial_prompt_embeddings
temporal_prompts = self.temporal_prompt_embeddings
prompts = torch.cat((spatial_prompts, temporal_prompts), dim=1)
else:
if self.deep and (not self.deep_shared) and layer > 0:
if self.use_bank:
k = self.num_tokens // self.pool_length
out = self.deep_prompt_pool[layer-1](query, k, istrain, gamma)
prompts = out['prompts']
ps_loss += out.get('ps_loss', 0)
else:
prompts = self.deep_prompt_embeddings[layer-1]
else:
if self.use_bank:
k = self.num_tokens // self.pool_length
out = self.prompt_pool(query, k, istrain, gamma)
prompts = out['prompts']
ps_loss += out.get('ps_loss', 0)
else:
prompts = self.prompt_embeddings
prompts = self.prompt_dropout(self.prompt_outproj(prompts))
return prompts, ps_loss
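
# Illustrative usage sketch (added for exposition; not from the original repo).
# With a pool of 16 length-1 entries and num_tokens=8, eight prompt tokens are
# retrieved per example and projected back to the transformer width. The query
# stands in for a CLS-like feature; all sizes are assumptions.
def _demo_visual_prompt_learner():
    vpl = VisualPromptLearner(embed_dim=768, prompt_dim=256, num_tokens=8,
                              pool={'size': 16, 'length': 1})
    query = torch.randn(4, 1, 768)        # (batch, 1, embed_dim)
    prompts, ps_loss = vpl(query, layer=0, istrain=False)
    assert prompts.shape == (4, 8, 768)
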
class CMM(nn.Module):
'''Context modeling module'''
def __init__(self, num_tokens=8, num_frames=16, embed_dim=768, prompt_dim=256, dropout=0., num_layer=1, shared=False, pool={}):
super(CMM, self).__init__()
self.num_tokens = num_tokens
self.num_frames = num_frames
self.embed_dim = embed_dim
self.prompt_dim = prompt_dim
self.pool_size = pool.get('size', 0)
self.pool_length = pool.get('length', 1)
        self.use_bank = self.pool_size > 0
self.use_rnn = not self.use_bank
if self.use_rnn:
self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim,
num_layers=1, batch_first=True, dropout=dropout, bidirectional=True)
self.shared = shared
self.prompt_dropout = nn.Dropout(dropout)
if self.use_bank:
print(f'Using feature bank with size {self.pool_size} (dimension: {prompt_dim})')
if self.use_rnn:
self.prompt_inproj = nn.Linear(embed_dim * 2, prompt_dim)
nn.init.kaiming_normal_(
self.prompt_inproj.weight, a=0, mode='fan_out')
else:
if embed_dim != prompt_dim:
self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
else:
self.prompt_inproj = nn.Identity()
self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
nn.init.kaiming_normal_(
self.prompt_outproj.weight, a=0, mode='fan_out')
if shared:
self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
else:
self.prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(num_layer)])
else:
self.fc = nn.Linear(embed_dim * 2, embed_dim * num_tokens)
def forward(self, x, layer=0, istrain=False, gamma=1.0):
        x = x.detach()
        # average patch tokens within each frame to get per-frame features
        x = rearrange(x, 'b (f n) d -> b f n d', f=self.num_frames)
        x = torch.mean(x, dim=2)
        if self.use_rnn:
            # model temporal context with a bidirectional LSTM
            x, _ = self.rnn(x)
ps_loss = x.new_zeros([1])
if self.use_bank:
query = self.prompt_inproj(x).flatten(0, 1)
k = self.num_tokens // self.pool_length
if self.shared:
out = self.prompt_pool(query, k, istrain, gamma)
else:
out = self.prompt_pool[layer](query, k, istrain, gamma)
prompts = rearrange(out['prompts'], '(b f) p d -> b (f p) d', f=self.num_frames)
prompts = self.prompt_outproj(prompts)
            # queries were flattened over frames, so scale the pooled loss back by num_frames
            ps_loss += out.get('ps_loss', 0) * self.num_frames
else:
prompts = self.fc(x)
prompts = rearrange(prompts, 'b f (p d) -> b (f p) d', p=self.num_tokens)
return prompts, ps_loss
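
# Illustrative usage sketch (added for exposition; not from the original repo).
# CMM turns a flattened (frames x patches) token sequence into num_tokens
# context prompts per frame; 196 patches corresponds to a 14x14 grid and is,
# like the other sizes, an assumption.
def _demo_cmm():
    cmm = CMM(num_tokens=8, num_frames=4, embed_dim=768,
              pool={'size': 32, 'length': 1})
    x = torch.randn(2, 4 * 196, 768)      # (batch, frames * patches, embed_dim)
    prompts, ps_loss = cmm(x, layer=0, istrain=False)
    assert prompts.shape == (2, 4 * 8, 768)


if __name__ == '__main__':
    # run the illustrative sketches above
    _demo_prompt_learner()
    _demo_prompt_pool()
    _demo_visual_prompt_learner()
    _demo_cmm()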