"""Encoder-level knowledge distillation (KD) wrappers.

Each wrapper wraps a captioning student `model`; whenever the input dict
carries precomputed teacher features under "tchr_output", the wrapper adds
an "enc_kd_loss" entry to the model's output dict.
"""
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

from models.base import CaptionMetaMixin
from utils.model_util import init

class WmlEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Distills several teacher layers at once: per-layer losses are combined
    with attention weights computed from a student query and per-layer
    teacher keys."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_layer_to_dims: Dict[str, int],
                 loss_type: str = "mse"):
        super().__init__()
        self.model = model
        self.tchr_layers = list(tchr_layer_to_dims.keys())
        # The student embedding yields both a query and a value, hence 2x.
        self.stdnt_qv_proj = nn.Linear(model.encoder.fc_emb_size,
                                       2 * shared_dim)
        self.stdnt_qv_proj.apply(init)
        # One key/value projection per distilled teacher layer.
        for layer, dim in tchr_layer_to_dims.items():
            self.add_module(f'tchr_kv_proj_{layer}',
                            nn.Linear(dim, 2 * shared_dim))
            getattr(self, f'tchr_kv_proj_{layer}').apply(init)
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")
        else:
            raise ValueError(f"unsupported loss_type: {loss_type}")

    def forward(self, input_dict: Dict):
        output_dict = self.model(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_qv = self.stdnt_qv_proj(stdnt_emb)
            stdnt_q, stdnt_v = torch.chunk(stdnt_qv, 2, dim=-1)

            tchr_output = input_dict["tchr_output"]
            layer_ks, layer_vs = [], []
            for layer in self.tchr_layers:
                layer_kv = getattr(self, f'tchr_kv_proj_{layer}')(tchr_output[layer])
                layer_k, layer_v = torch.chunk(layer_kv, 2, dim=-1)
                layer_ks.append(layer_k)
                layer_vs.append(layer_v)
            layer_ks = torch.stack(layer_ks, dim=1)  # (batch, n_layers, shared_dim)
            layer_vs = torch.stack(layer_vs, dim=1)  # (batch, n_layers, shared_dim)
            # Student query attends over teacher layer keys: (batch, 1, n_layers).
            weights = torch.softmax(stdnt_q.unsqueeze(1) @ layer_ks.transpose(1, 2),
                                    dim=-1)
            stdnt_v = repeat(stdnt_v, 'b d -> b n d', n=len(self.tchr_layers))
            # Per-layer MSE, weighted by the attention distribution over layers.
            loss = self.loss_fn(stdnt_v, layer_vs).mean(dim=-1, keepdim=True)
            loss = (weights @ loss).mean()
            output_dict["enc_kd_loss"] = loss
        return output_dict
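
# A minimal usage sketch (illustrative; the layer names, dims, and variables
# below are hypothetical). The wrapper assumes the student exposes
# `encoder.fc_emb_size` and that teacher features for each named layer are
# precomputed and pooled to (batch, dim):
#
#     wrapper = WmlEncoderKdWrapper(
#         student, shared_dim=512,
#         tchr_layer_to_dims={"layer_6": 768, "layer_12": 768})
#     input_dict["tchr_output"] = {"layer_6": feats6, "layer_12": feats12}
#     output = wrapper(input_dict)
#     total_loss = caption_loss + kd_weight * output["enc_kd_loss"]
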
class MseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Distillation via an MSE loss between (optionally projected) student
    and teacher embeddings."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            # Without a teacher projection the student is projected directly
            # into the teacher space, so the dimensions must match.
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size, shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size, shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if not unsup:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                # Project encoder outputs into the teacher space before
                # decoding, so the decoder consumes the distilled features.
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            # Unsupervised batches carry no captions: run only the encoder.
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)

            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)

            loss = F.mse_loss(stdnt_emb, tchr_emb)
            output_dict["enc_kd_loss"] = loss
        return output_dict
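
# Usage sketch (illustrative; `student` and `tchr_emb` are hypothetical).
# With `use_tchr_proj=False` the student encoder outputs are projected
# straight into the teacher space before decoding, so `shared_dim` must
# equal `tchr_dim`:
#
#     wrapper = MseEncoderKdWrapper(student, shared_dim=1024, tchr_dim=1024,
#                                   use_tchr_proj=False, l2_norm=True)
#     input_dict["tchr_output"] = {"embedding": tchr_emb}  # (batch, tchr_dim)
#     output = wrapper(input_dict)  # contains "enc_kd_loss"
#
# Batches flagged with input_dict["unsup"] = True skip the decoder entirely
# and can still contribute the KD loss.
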
class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Distillation via a symmetric (CLIP-style) contrastive loss between
    student and teacher embeddings."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size, shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size, shared_dim)
        self.stdnt_proj.apply(init)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.tchr_proj.apply(init)
        # Learnable log inverse-temperature, initialized as in CLIP (1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if not unsup:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)

            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)

            # Symmetric InfoNCE: matched student/teacher pairs are positives,
            # every other pairing in the batch is a negative. logit_scale is
            # stored in log space, so exponentiate before scaling.
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            logit = self.logit_scale.exp() * unscaled_logit
            label = torch.arange(logit.shape[0], device=logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = loss
        return output_dict
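
# Shape sketch (illustrative; `student` and `tchr_emb` are hypothetical).
# With a batch of N student/teacher pairs, `logit` is an N x N similarity
# matrix and the targets are its diagonal:
#
#     wrapper = ContraEncoderKdWrapper(student, shared_dim=1024, tchr_dim=512)
#     input_dict["tchr_output"] = {"embedding": tchr_emb}  # (N, 512)
#     output = wrapper(input_dict)
#     output["enc_kd_loss"]  # mean of the two cross-entropy directions
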
class ContraMseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Distillation combining the MSE and contrastive objectives:
    `enc_kd_loss` is the sum of the two."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            # Without a teacher projection the student is projected directly
            # into the teacher space, so the dimensions must match.
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size, shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size, shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()
        # Learnable log inverse-temperature, initialized as in CLIP (1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if not unsup:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                # Project encoder outputs into the teacher space before
                # decoding, so the decoder consumes the distilled features.
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            # Unsupervised batches carry no captions: run only the encoder.
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)

            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)

            mse_loss = F.mse_loss(stdnt_emb, tchr_emb)

            # The contrastive term always uses unit-norm embeddings (a no-op
            # when l2_norm has already normalized them above).
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            # logit_scale is stored in log space, so exponentiate before scaling.
            logit = self.logit_scale.exp() * unscaled_logit
            label = torch.arange(logit.shape[0], device=logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            cntr_loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = mse_loss + cntr_loss

        return output_dict
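
if __name__ == "__main__":
    # Quick smoke test (a sketch, not part of the training pipeline).
    # `_DummyStudent` is a hypothetical stand-in for a real captioning model:
    # the wrappers only assume it exposes `fc_emb_size` (or
    # `encoder.fc_emb_size`) and returns a dict containing "fc_emb". This
    # also assumes CaptionMetaMixin imposes no init-time requirements.
    class _DummyStudent(nn.Module):
        def __init__(self, feat_dim: int = 64, fc_emb_size: int = 128):
            super().__init__()
            self.fc_emb_size = fc_emb_size
            self.proj = nn.Linear(feat_dim, fc_emb_size)

        def forward(self, input_dict: Dict):
            return {"fc_emb": self.proj(input_dict["feat"])}

    wrapper = ContraMseEncoderKdWrapper(_DummyStudent(),
                                        shared_dim=256, tchr_dim=512)
    batch = {
        "feat": torch.randn(8, 64),
        "tchr_output": {"embedding": torch.randn(8, 512)},
    }
    print(wrapper(batch)["enc_kd_loss"])  # scalar KD loss (MSE + contrastive)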