import math

import torch
import torch.nn as nn

from models import BaseDecoder
from utils.model_util import generate_length_mask, PositionalEncoding
from utils.train_util import merge_load_state_dict


class TransformerDecoder(BaseDecoder):

    def __init__(self,
                 emb_dim,
                 vocab_size,
                 fc_emb_dim,
                 attn_emb_dim,
                 dropout,
                 freeze=False,
                 tie_weights=False,
                 **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout, tie_weights=tie_weights)
        self.d_model = emb_dim
        # default to one attention head per 64 model dimensions
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        layer = nn.TransformerDecoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerDecoder(layer, self.nlayers)
        self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
        if tie_weights:
            self.classifier.weight = self.word_embedding.weight
        # project encoder (attention) embeddings to the decoder dimension
        self.attn_proj = nn.Sequential(
            nn.Linear(self.attn_emb_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        self.init_params()

        self.freeze = freeze
        if freeze:
            for p in self.parameters():
                p.requires_grad = False

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_pretrained(self, pretrained, output_fn):
        checkpoint = torch.load(pretrained, map_location="cpu")

        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        if next(iter(checkpoint)).startswith("decoder."):
            # strip the "decoder." prefix from checkpoint keys
            state_dict = {}
            for k, v in checkpoint.items():
                state_dict[k[len("decoder."):]] = v
        else:
            state_dict = checkpoint

        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            # only freeze the parameters that were actually loaded
            for name, param in self.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
                else:
                    param.requires_grad = True

    def generate_square_subsequent_mask(self, max_length):
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
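    # For max_length = 3 the mask above evaluates to
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]],
    # i.e. token i may only attend to tokens at positions <= i.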

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]

        # (batch, time, attn_emb_dim) -> (time, batch, d_model)
        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)
        word = word.to(attn_emb.device)
        # (batch, seq) -> (seq, batch, d_model), scaled and position-encoded
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)
        embed = embed.transpose(0, 1)
        embed = self.pos_encoder(embed)

        # causal mask over caption tokens, padding mask over encoder frames
        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output
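
# Minimal usage sketch for TransformerDecoder (hypothetical shapes and values,
# not part of the training pipeline):
#
#     decoder = TransformerDecoder(emb_dim=256, vocab_size=5000,
#                                  fc_emb_dim=512, attn_emb_dim=512, dropout=0.2)
#     out = decoder({
#         "word": torch.randint(0, 5000, (4, 20)),                   # (batch, seq)
#         "attn_emb": torch.randn(4, 100, 512),                      # (batch, time, attn_emb_dim)
#         "attn_emb_len": torch.tensor([100, 80, 60, 40]),           # valid frames per clip
#         "cap_padding_mask": torch.zeros(4, 20, dtype=torch.bool),  # True marks padded tokens
#     })
#     out["logit"].shape  # -> (4, 20, 5000)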


class M2TransformerDecoder(BaseDecoder):

    def __init__(self, vocab_size, fc_emb_dim, attn_emb_dim, dropout=0.1, **kwargs):
        super().__init__(attn_emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout)
        try:
            from m2transformer.models.transformer import MeshedDecoder
        except ImportError:
            raise ImportError(
                "meshed-memory-transformer not installed; please run "
                "`pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`")
        # MeshedDecoder embeds the tokens itself, so drop the base modules
        del self.word_embedding
        del self.in_dropout

        self.d_model = attn_emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
        self.model = MeshedDecoder(vocab_size, 100, self.nlayers, 0,
                                   d_model=self.d_model,
                                   h=self.nhead,
                                   d_ff=self.dim_feedforward,
                                   dropout=dropout)
        self.init_params()

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_mask = input_dict["attn_emb_mask"]
        word = word.to(attn_emb.device)
        embed, logit = self.model(word, attn_emb, attn_emb_mask)
        output = {
            "embed": embed,
            "logit": logit,
        }
        return output


class EventTransformerDecoder(TransformerDecoder):

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        event_emb = input_dict["event"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)

        embed = embed.transpose(0, 1)
        # condition every token embedding on the event embedding
        embed += event_emb
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output


class KeywordProbTransformerDecoder(TransformerDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, keyword_classes_num, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, **kwargs)
        # project keyword probabilities into the embedding space
        self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
        self.word_keyword_norm = nn.LayerNorm(self.d_model)

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        keyword = input_dict["keyword"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)

        embed = embed.transpose(0, 1)
        # add the projected keyword representation to every token embedding
        embed += self.keyword_proj(keyword)
        embed = self.word_keyword_norm(embed)

        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output
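
# Sketch of a keyword-conditioned call (hypothetical values; "keyword" is assumed to
# hold per-clip keyword probabilities of shape (batch, keyword_classes_num)):
#
#     decoder = KeywordProbTransformerDecoder(emb_dim=256, vocab_size=5000,
#                                             fc_emb_dim=512, attn_emb_dim=512,
#                                             dropout=0.2, keyword_classes_num=300)
#     input_dict["keyword"] = torch.rand(4, 300)  # broadcast over the caption length
#     out = decoder(input_dict)  # same remaining keys as TransformerDecoder.forward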