from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

from models.base import CaptionMetaMixin
from utils.model_util import init
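
# Knowledge-distillation (KD) wrappers around a student captioning model.
# Each wrapper forwards the batch through the wrapped model and, when the
# batch carries a "tchr_output" entry, adds an encoder-level KD loss to the
# output dict under "enc_kd_loss":
#   - WmlEncoderKdWrapper: attention-weighted MSE against multiple teacher layers
#   - MseEncoderKdWrapper: MSE between projected student and teacher embeddings
#   - ContraEncoderKdWrapper: symmetric contrastive (InfoNCE-style) loss
#   - ContraMseEncoderKdWrapper: sum of the MSE and contrastive objectives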


class WmlEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Weighted multi-layer (WML) encoder KD: the student embedding attends over
    projected teacher layers and the per-layer MSE is aggregated with the
    resulting attention weights."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_layer_to_dims: Dict[str, int],
                 loss_type: str = "mse"):
        super().__init__()
        self.model = model
        self.tchr_layers = list(tchr_layer_to_dims.keys())
        # project the student embedding to a query/value pair in the shared KD space
        self.stdnt_qv_proj = nn.Linear(model.encoder.fc_emb_size,
                                       2 * shared_dim)
        self.stdnt_qv_proj.apply(init)
        # one key/value projection per teacher layer
        for layer, dim in tchr_layer_to_dims.items():
            self.add_module(f'tchr_kv_proj_{layer}', nn.Linear(dim, 2 * shared_dim))
            getattr(self, f'tchr_kv_proj_{layer}').apply(init)
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")
        else:
            raise ValueError(f"unsupported loss_type: {loss_type}")

    def forward(self, input_dict: Dict):
        output_dict = self.model(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_qv = self.stdnt_qv_proj(stdnt_emb)
            stdnt_q, stdnt_v = torch.chunk(stdnt_qv, 2, dim=-1)
            tchr_output = input_dict["tchr_output"]
            layer_ks, layer_vs = [], []
            for layer in self.tchr_layers:
                layer_kv = getattr(self, f'tchr_kv_proj_{layer}')(tchr_output[layer])
                layer_k, layer_v = torch.chunk(layer_kv, 2, dim=-1)
                layer_ks.append(layer_k)
                layer_vs.append(layer_v)
            layer_ks = torch.stack(layer_ks, dim=1)  # [batch, n_layers, shared_dim]
            layer_vs = torch.stack(layer_vs, dim=1)  # [batch, n_layers, shared_dim]
            # attention of the student query over teacher-layer keys: [batch, 1, n_layers]
            weights = torch.softmax(stdnt_q.unsqueeze(1) @ layer_ks.transpose(1, 2), dim=-1)
            stdnt_v = repeat(stdnt_v, 'b d -> b n d', n=len(self.tchr_layers))
            # per-layer MSE, aggregated with the attention weights
            loss = self.loss_fn(stdnt_v, layer_vs).mean(dim=-1, keepdim=True)
            loss = (weights @ loss).mean()
            output_dict["enc_kd_loss"] = loss
        return output_dict
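
# The KD loss of WmlEncoderKdWrapper, written out (shapes inferred from the
# projections in __init__): with student query/value q_b, v_b and per-layer
# teacher keys/values k_{b,n}, v'_{b,n} in the shared space,
#     w_{b,n} = softmax_n(q_b . k_{b,n})
#     enc_kd_loss = mean_b( sum_n w_{b,n} * MSE(v_b, v'_{b,n}) )
# The teacher activations are expected as
#     input_dict["tchr_output"][layer] -> Tensor[batch, tchr_layer_to_dims[layer]].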


class MseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Encoder KD with an MSE loss between (optionally projected) student and
    teacher embeddings."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False,
                 ):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            # teacher embeddings are used as-is; only the student is projected
            self.tchr_proj = nn.Identity()

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                # without a teacher projection, the student encoder outputs are
                # projected into the teacher space before decoding
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            # unsupervised batch: only the encoder is run
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)
            loss = F.mse_loss(stdnt_emb, tchr_emb)
            output_dict["enc_kd_loss"] = loss
        return output_dict


class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Encoder KD with a symmetric contrastive (InfoNCE-style) loss between
    student and teacher embeddings."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 ):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.tchr_proj.apply(init)
        # learnable scale, initialised to log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            # in-batch similarity matrix; matching pairs sit on the diagonal
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = loss
        return output_dict


class ContraMseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Encoder KD combining the MSE and contrastive objectives above; the final
    KD loss is their sum."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False,
                 ):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()
        # learnable scale for the contrastive term, initialised to log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                # project student encoder outputs into the teacher space before decoding
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            # unsupervised batch: only the encoder is run
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)
            mse_loss = F.mse_loss(stdnt_emb, tchr_emb)
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            cntr_loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = mse_loss + cntr_loss
        return output_dict
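

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training code).
    # ``DummyCaptionModel`` and the ``"feat"`` input key are hypothetical
    # stand-ins: only the attributes the wrapper actually touches
    # (``encoder.fc_emb_size`` and a forward returning ``"fc_emb"``) are mocked,
    # and the teacher embedding is random noise of an assumed dimension.
    class DummyEncoder(nn.Module):
        def __init__(self, in_dim: int, fc_emb_size: int):
            super().__init__()
            self.fc_emb_size = fc_emb_size
            self.proj = nn.Linear(in_dim, fc_emb_size)

        def forward(self, input_dict: Dict):
            return {"fc_emb": self.proj(input_dict["feat"])}

    class DummyCaptionModel(nn.Module):
        def __init__(self, in_dim: int = 64, fc_emb_size: int = 256):
            super().__init__()
            self.encoder = DummyEncoder(in_dim, fc_emb_size)

        def forward(self, input_dict: Dict):
            return self.encoder(input_dict)

    batch_size, tchr_dim = 4, 512
    wrapper = ContraEncoderKdWrapper(DummyCaptionModel(), shared_dim=128,
                                     tchr_dim=tchr_dim)
    output = wrapper({
        "feat": torch.randn(batch_size, 64),
        "tchr_output": {"embedding": torch.randn(batch_size, tchr_dim)},
    })
    print("enc_kd_loss:", output["enc_kd_loss"].item())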