import math

import torch
import torch.nn as nn

from models import BaseDecoder
from utils.model_util import generate_length_mask, PositionalEncoding
from utils.train_util import merge_load_state_dict


class TransformerDecoder(BaseDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, freeze=False, tie_weights=False, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout, tie_weights=tie_weights)
        self.d_model = emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        layer = nn.TransformerDecoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerDecoder(layer, self.nlayers)
        self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
        if tie_weights:
            # share the input embedding and the output projection weights
            self.classifier.weight = self.word_embedding.weight
        # project encoder (attention) embeddings to the decoder dimension
        self.attn_proj = nn.Sequential(
            nn.Linear(self.attn_emb_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        self.init_params()

        self.freeze = freeze
        if freeze:
            for p in self.parameters():
                p.requires_grad = False

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_pretrained(self, pretrained, output_fn):
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        if next(iter(checkpoint)).startswith("decoder."):
            # strip the "decoder." prefix from keys saved with a full model
            state_dict = {}
            for k, v in checkpoint.items():
                state_dict[k[8:]] = v
        else:
            # checkpoint already stores decoder parameters directly
            state_dict = checkpoint
        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            # keep only the loaded parameters frozen; train the rest
            for name, param in self.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
                else:
                    param.requires_grad = True

    def generate_square_subsequent_mask(self, max_length):
        # causal mask: position t may only attend to positions <= t
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]
        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(
            attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output
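# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes
# BaseDecoder builds `word_embedding`, `in_dropout`, `emb_dim` and
# `attn_emb_dim` from the constructor arguments (as the forward pass above
# relies on) and that `generate_length_mask` returns a [N, T_src] bool mask
# that is True at valid frames. All sizes below are made up for illustration.
# ---------------------------------------------------------------------------
def _transformer_decoder_usage_sketch():
    vocab_size, emb_dim, attn_emb_dim = 100, 256, 512
    batch_size, cap_len, src_len = 4, 15, 31
    decoder = TransformerDecoder(emb_dim=emb_dim,
                                 vocab_size=vocab_size,
                                 fc_emb_dim=attn_emb_dim,
                                 attn_emb_dim=attn_emb_dim,
                                 dropout=0.2)
    input_dict = {
        # caption token indices, [N, T]
        "word": torch.randint(0, vocab_size, (batch_size, cap_len)),
        # encoder (e.g. audio) embeddings, [N, T_src, attn_emb_dim]
        "attn_emb": torch.randn(batch_size, src_len, attn_emb_dim),
        # number of valid encoder frames per sample
        "attn_emb_len": torch.randint(src_len // 2, src_len + 1, (batch_size,)),
        # True marks padded caption positions (none in this sketch)
        "cap_padding_mask": torch.zeros(batch_size, cap_len, dtype=torch.bool),
    }
    output = decoder(input_dict)
    # output["embed"]: [N, T, d_model], output["logit"]: [N, T, vocab_size]
    return output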
class M2TransformerDecoder(BaseDecoder):

    def __init__(self, vocab_size, fc_emb_dim, attn_emb_dim, dropout=0.1, **kwargs):
        super().__init__(attn_emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout)
        try:
            from m2transformer.models.transformer import MeshedDecoder
        except ImportError:
            raise ImportError("meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`")
        # the meshed decoder handles word embedding and dropout internally
        del self.word_embedding
        del self.in_dropout

        self.d_model = attn_emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.model = MeshedDecoder(vocab_size, 100, self.nlayers, 0,
                                   d_model=self.d_model,
                                   h=self.nhead,
                                   d_ff=self.dim_feedforward,
                                   dropout=dropout)
        self.init_params()

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_mask = input_dict["attn_emb_mask"]
        word = word.to(attn_emb.device)
        embed, logit = self.model(word, attn_emb, attn_emb_mask)
        output = {
            "embed": embed,
            "logit": logit,
        }
        return output


class EventTransformerDecoder(TransformerDecoder):

    def forward(self, input_dict):
        word = input_dict["word"]  # index of word embeddings
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        event_emb = input_dict["event"]  # [N, emb_dim]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]
        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        # add the event embedding to every decoding step
        embed += event_emb
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(
            attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output


class KeywordProbTransformerDecoder(TransformerDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, keyword_classes_num, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, **kwargs)
        self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
        self.word_keyword_norm = nn.LayerNorm(self.d_model)

    def forward(self, input_dict):
        word = input_dict["word"]  # index of word embeddings
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        keyword = input_dict["keyword"]  # [N, keyword_classes_num]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]
        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        # fuse projected keyword probabilities into the word embeddings
        embed += self.keyword_proj(keyword)
        embed = self.word_keyword_norm(embed)
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(
            attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output
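# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It exercises
# KeywordProbTransformerDecoder, which takes the same inputs as
# TransformerDecoder plus a per-sample keyword probability vector. The same
# assumptions about BaseDecoder and generate_length_mask as in the usage
# sketch above apply; all sizes are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, emb_dim, attn_emb_dim, keyword_num = 100, 256, 512, 300
    batch_size, cap_len, src_len = 4, 15, 31
    decoder = KeywordProbTransformerDecoder(emb_dim=emb_dim,
                                            vocab_size=vocab_size,
                                            fc_emb_dim=attn_emb_dim,
                                            attn_emb_dim=attn_emb_dim,
                                            dropout=0.2,
                                            keyword_classes_num=keyword_num)
    input_dict = {
        "word": torch.randint(0, vocab_size, (batch_size, cap_len)),
        "attn_emb": torch.randn(batch_size, src_len, attn_emb_dim),
        "attn_emb_len": torch.randint(src_len // 2, src_len + 1, (batch_size,)),
        "cap_padding_mask": torch.zeros(batch_size, cap_len, dtype=torch.bool),
        # keyword probabilities, [N, keyword_classes_num]
        "keyword": torch.rand(batch_size, keyword_num),
    }
    output = decoder(input_dict)
    print(output["logit"].shape)  # expected: [N, T, vocab_size]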