# lavila/models/prompt_tuning.py
import math
from functools import reduce
from operator import mul

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class PromptLearner(nn.Module):
def __init__(self, ctx_dim=512, n_ctx=16):
super(PromptLearner, self).__init__()
self.n_ctx = n_ctx
self.ctx_dim = ctx_dim
# initialize prompts
ctx_vectors = torch.empty(n_ctx, ctx_dim)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
self.ctx = nn.Parameter(ctx_vectors) # to be optimized
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens): {n_ctx}")
def forward(self):
return self.ctx
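
# A minimal usage sketch (illustrative, not part of the original pipeline): the
# learned context is meant to replace the "X X ... X" placeholder, i.e. to be
# concatenated with class-name token embeddings before the text encoder,
# CoOp-style.
#   learner = PromptLearner(ctx_dim=512, n_ctx=16)
#   ctx = learner()    # (16, 512) learnable context tokens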
class PromptPoolLearner(nn.Module):
def __init__(self, prompt_dim=256, size=128, length=1):
super(PromptPoolLearner, self).__init__()
self.prompt_dim = prompt_dim
self.length = length
self.size = size
        # initialize the prompt pool
        self.prompt_values = nn.Parameter(torch.zeros(size, length, prompt_dim))
        # selection counts per prompt; a non-persistent buffer so it follows the
        # module's device instead of being hard-coded to CUDA
        self.register_buffer('id_table', torch.ones(size), persistent=False)
        # uniform initialization in [-1, 1]
        nn.init.uniform_(self.prompt_values.data, -1, 1)
def l2_normalize(self, x, dim=None, epsilon=1e-12):
"""Normalizes a given vector or matrix."""
square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
return x * x_inv_norm
    def forward(self, query, k=0, istrain=False, gamma=1.0):
        BZ = query.shape[0]
        out = dict()
        # match l2-normalized queries against l2-normalized pool keys
        query = self.l2_normalize(query.squeeze(1), dim=1)
        keys = self.prompt_values.mean(dim=1)  # one key per pool entry (mean over length)
        keys = self.l2_normalize(keys, dim=1)
        similarity = torch.matmul(query, keys.t())
        if 0 < k < self.size:
            if istrain:
                # blend cosine similarity with an inverse-frequency prior so that
                # rarely selected prompts still get sampled during training
                inv_freq = self.id_table.sum() / self.id_table
                weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
                idx = torch.multinomial(weights, k, replacement=False)
            else:
                # at test time, deterministically pick the top-k most similar prompts
                idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
            prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
            self.id_table[prompt_id] += id_counts
            prompts = self.prompt_values[idx.flatten(), ...].view(BZ, k * self.length, self.prompt_dim)
        else:
            # k == 0 (or k >= pool size): use the entire pool
            idx = torch.arange(self.size, device=query.device).unsqueeze(0).expand(BZ, -1)
            prompts = self.prompt_values.flatten(0, 1).unsqueeze(0).expand(BZ, -1, -1)
        prompts = self.l2_normalize(prompts, dim=-1)
        out['prompts'] = prompts
        # prompt-selection loss: pull the similarity-weighted combination of the
        # selected keys toward the query, and push distinct keys apart
        sel_sim = similarity[torch.arange(BZ, device=query.device).view(-1, 1), idx]
        # view with -1 rather than k so this also works when the whole pool is selected
        sel_key = keys[idx.flatten(), ...].view(BZ, -1, self.prompt_dim)
        diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
        ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size, device=keys.device))) / BZ
        out['ps_loss'] = diff + ksim
        return out
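
# A minimal usage sketch (illustrative; shapes follow the code above):
#   pool = PromptPoolLearner(prompt_dim=256, size=128, length=1)
#   query = torch.randn(4, 1, 256)     # one query feature per sample
#   out = pool(query, k=8, istrain=True)
#   out['prompts']                     # (4, 8, 256) selected and normalized prompts
#   out['ps_loss']                     # query-key alignment + key-diversity loss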
class VisualPromptLearner(nn.Module):
    def __init__(self, patch_size=16, embed_dim=768, num_layers=12, prompt_dim=256, num_tokens=5, deep=False,
                 deep_shared=False, split_st=False, dropout=0.1, pool=None):
        super(VisualPromptLearner, self).__init__()
        # copy the pool config: it is mutated below when split_st is set, and a
        # mutable default argument would otherwise leak state across instances
        pool = dict(pool) if pool else {}
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.prompt_dim = prompt_dim
        self.num_tokens = num_tokens  # number of prompted tokens
        self.prompt_dropout = nn.Dropout(dropout)
        pool_size = pool.get('size', 0)
        self.pool_length = pool.get('length', 1)
        self.use_bank = pool_size > 0 and num_tokens <= pool_size * self.pool_length
if self.use_bank:
print(f'Using feature bank with size {pool_size} (dimension: {prompt_dim})')
if prompt_dim != embed_dim:
self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
else:
self.prompt_inproj = nn.Identity()
if self.use_bank:
self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
nn.init.kaiming_normal_(
self.prompt_outproj.weight, a=0, mode='fan_out')
else:
self.prompt_outproj = nn.Identity()
self.split_st = split_st # split spatial and temporal prompts
        # initialize prompts; the bound follows xavier-uniform with the
        # patch-embedding fan-in (3 * patch_size * patch_size), as in VPT
        val = math.sqrt(6. / float(3 * reduce(mul, (patch_size, patch_size), 1) + prompt_dim))
        if split_st:
            if self.use_bank:
                # split the pool budget evenly between spatial and temporal prompts
                pool['size'] //= 2
                self.spatial_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
                self.temporal_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
else:
self.spatial_prompt_embeddings = nn.Parameter(torch.zeros(
1, num_tokens // 2, prompt_dim))
self.temporal_prompt_embeddings = nn.Parameter(torch.zeros(
1, num_tokens // 2, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.spatial_prompt_embeddings.data, -val, val)
nn.init.uniform_(self.temporal_prompt_embeddings.data, -val, val)
else:
if self.use_bank:
self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
else:
self.prompt_embeddings = nn.Parameter(torch.zeros(
1, num_tokens, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.prompt_embeddings.data, -val, val)
self.deep = deep or deep_shared
self.deep_shared = deep_shared
if deep and (not deep_shared):
total_d_layer = num_layers - 1
if split_st:
if self.use_bank:
self.spatial_deep_prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(total_d_layer)])
self.temporal_deep_prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(total_d_layer)])
else:
self.spatial_deep_prompt_embeddings = nn.Parameter(torch.zeros(
total_d_layer, num_tokens // 2, prompt_dim))
self.temporal_deep_prompt_embeddings = nn.Parameter(torch.zeros(
total_d_layer, num_tokens // 2, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.spatial_deep_prompt_embeddings.data, -val, val)
nn.init.uniform_(self.temporal_deep_prompt_embeddings.data, -val, val)
else:
if self.use_bank:
self.deep_prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(total_d_layer)])
else:
self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
total_d_layer, num_tokens, prompt_dim))
# xavier_uniform initialization
nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
    def forward(self, query=None, layer=0, istrain=False, gamma=1.0):
        if query is not None:
            # the query only guides prompt selection; do not backprop through it
            query = self.prompt_inproj(query.detach())
            ps_loss = query.new_zeros([1])
        else:
            # no query is needed when prompts are plain learned embeddings
            ps_loss = next(self.parameters()).new_zeros([1])
if self.split_st:
if self.deep and (not self.deep_shared) and layer > 0:
if self.use_bank:
k = (self.num_tokens // 2) // self.pool_length
spatial_out = self.spatial_deep_prompt_pool[layer-1](query, k, istrain, gamma)
spatial_prompts = spatial_out['prompts']
temporal_out = self.temporal_deep_prompt_pool[layer-1](query, k, istrain, gamma)
temporal_prompts = temporal_out['prompts']
ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
else:
spatial_prompts = self.spatial_deep_prompt_embeddings[layer-1]
temporal_prompts = self.temporal_deep_prompt_embeddings[layer-1]
else:
if self.use_bank:
k = (self.num_tokens // 2) // self.pool_length
spatial_out = self.spatial_prompt_pool(query, k, istrain, gamma)
spatial_prompts = spatial_out['prompts']
temporal_out = self.temporal_prompt_pool(query, k, istrain, gamma)
temporal_prompts = temporal_out['prompts']
ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
else:
spatial_prompts = self.spatial_prompt_embeddings
temporal_prompts = self.temporal_prompt_embeddings
prompts = torch.cat((spatial_prompts, temporal_prompts), dim=1)
else:
if self.deep and (not self.deep_shared) and layer > 0:
if self.use_bank:
k = self.num_tokens // self.pool_length
out = self.deep_prompt_pool[layer-1](query, k, istrain, gamma)
prompts = out['prompts']
ps_loss += out.get('ps_loss', 0)
else:
prompts = self.deep_prompt_embeddings[layer-1]
else:
if self.use_bank:
k = self.num_tokens // self.pool_length
out = self.prompt_pool(query, k, istrain, gamma)
prompts = out['prompts']
ps_loss += out.get('ps_loss', 0)
else:
prompts = self.prompt_embeddings
prompts = self.prompt_dropout(self.prompt_outproj(prompts))
return prompts, ps_loss
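
# A minimal usage sketch (illustrative): without a pool the prompts are plain
# learned embeddings and no query is needed; with a pool they are selected per
# sample from a (detached) query feature.
#   vpl = VisualPromptLearner(num_tokens=4, prompt_dim=768)   # plain embeddings
#   prompts, ps_loss = vpl()                                  # prompts: (1, 4, 768)
#   vpl = VisualPromptLearner(num_tokens=4, pool={'size': 16, 'length': 1})
#   prompts, ps_loss = vpl(query=torch.randn(2, 1, 768))      # prompts: (2, 4, 768)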
class CMM(nn.Module):
    '''Context modeling module'''
    def __init__(self, num_tokens=8, num_frames=16, embed_dim=768, prompt_dim=256,
                 dropout=0., num_layer=1, shared=False, pool=None):
        super(CMM, self).__init__()
        # copy the pool config to avoid the shared-mutable-default pitfall
        pool = dict(pool) if pool else {}
        self.num_tokens = num_tokens
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.prompt_dim = prompt_dim
        self.pool_size = pool.get('size', 0)
        self.pool_length = pool.get('length', 1)
        self.use_bank = self.pool_size > 0
        self.use_rnn = not self.use_bank
        if self.use_rnn:
            # nn.LSTM applies dropout only between stacked layers, so the dropout
            # argument would be a no-op (and trigger a warning) with num_layers=1
            self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim,
                               num_layers=1, batch_first=True, bidirectional=True)
self.shared = shared
self.prompt_dropout = nn.Dropout(dropout)
if self.use_bank:
print(f'Using feature bank with size {self.pool_size} (dimension: {prompt_dim})')
if self.use_rnn:
self.prompt_inproj = nn.Linear(embed_dim * 2, prompt_dim)
nn.init.kaiming_normal_(
self.prompt_inproj.weight, a=0, mode='fan_out')
else:
if embed_dim != prompt_dim:
self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
else:
self.prompt_inproj = nn.Identity()
self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
nn.init.kaiming_normal_(
self.prompt_outproj.weight, a=0, mode='fan_out')
if shared:
self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
else:
self.prompt_pool = nn.ModuleList([
PromptPoolLearner(prompt_dim, **pool)
for i in range(num_layer)])
else:
self.fc = nn.Linear(embed_dim * 2, embed_dim * num_tokens)
    def forward(self, x, layer=0, istrain=False, gamma=1.0):
        x = x.detach()
        # average patch tokens within each frame to get one feature per frame
        x = rearrange(x, 'b (f n) d -> b f n d', f=self.num_frames)
        x = torch.mean(x, dim=2)
        if self.use_rnn:
            # model temporal context across frames with a bidirectional LSTM
            x, _ = self.rnn(x)
ps_loss = x.new_zeros([1])
if self.use_bank:
query = self.prompt_inproj(x).flatten(0, 1)
k = self.num_tokens // self.pool_length
if self.shared:
out = self.prompt_pool(query, k, istrain, gamma)
else:
out = self.prompt_pool[layer](query, k, istrain, gamma)
prompts = rearrange(out['prompts'], '(b f) p d -> b (f p) d', f=self.num_frames)
prompts = self.prompt_outproj(prompts)
ps_loss += out.get('ps_loss', 0) * self.num_frames
else:
prompts = self.fc(x)
prompts = rearrange(prompts, 'b f (p d) -> b (f p) d', p=self.num_tokens)
return prompts, ps_loss