# doc-to-slides / BartSE.py
import torch
from torch.utils.data import Dataset


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right, prepending `decoder_start_token_id`.
    """
    if pad_token_id is None:
        raise ValueError("pad_token_id has to be defined.")

    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
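

# Usage sketch (illustrative, not from the original file): with BART's
# pad_token_id=1 and decoder_start_token_id=2, a batch of label ids is
# shifted like this:
#
#     labels = torch.tensor([[42, 43, 2]])
#     shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
#     # -> tensor([[ 2, 42, 43]])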


def shift_embed_right(input_embed: torch.Tensor, pad_token_id: int, decoder_start_token_embed: torch.Tensor):
    """
    Shift input embeddings one position to the right along the sequence axis,
    prepending the decoder start token embedding.
    """
    if pad_token_id is None:
        raise ValueError("pad_token_id has to be defined.")

    shifted_input_embed = input_embed.new_zeros(input_embed.shape)
    shifted_input_embed[:, 1:, :] = input_embed[:, :-1, :].clone()
    shifted_input_embed[:, 0, :] = decoder_start_token_embed

    return shifted_input_embed
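

# Usage sketch (illustrative shapes, not from the original file): the embedding
# analogue of shift_tokens_right, operating on (batch, seq_len, d_model):
#
#     embeds = torch.randn(2, 16, 768)
#     start = torch.randn(768)
#     shifted = shift_embed_right(embeds, pad_token_id=1, decoder_start_token_embed=start)
#     # shifted[:, 0] == start, and shifted[:, 1:] == embeds[:, :-1]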


class BARTAutoEncoder(torch.nn.Module):
    """Chain a BART summarization model into a BART expansion model so the
    pair can be trained end to end to reconstruct the original input."""

    def __init__(self, summarizor, expandor, device):
        super().__init__()
        self.summarizor = summarizor.to(device)
        self.expandor = expandor.to(device)
        self.config = summarizor.config
        # Frozen embedding of the decoder start token (id 2 for BART).
        start_id = torch.tensor([self.config.decoder_start_token_id], device=device)
        self.start_token_embed = self.summarizor.model.shared(start_id).detach().clone().requires_grad_(False)
        # Frozen copy of the shared token-embedding matrix, used in forward()
        # to map vocabulary-space logits back into embedding space.
        self.shared_weight = self.summarizor.model.shared.weight.detach().clone().requires_grad_(False)
        self.device = device

    def forward(self, input):
        summ_input = dict()
        summ_input['input_ids'] = input['summ_input_ids'].to(self.device)
        summ_input['attention_mask'] = input['summ_attention_mask'].to(self.device)
        # Summarizer pass: logits of shape (batch, seq_len, vocab_size).
        outputs_summ = self.summarizor(**summ_input)
        # Soft embedding lookup: project the vocabulary logits back into
        # embedding space, giving (batch, seq_len, d_model) expander inputs.
        outputs_summ_embeds = outputs_summ.logits @ self.shared_weight
        # Teacher forcing for the expander: embed the target ids and shift the
        # embeddings one position right (BART's pad_token_id is 1).
        decoder_embed = self.expandor.model.shared(input['exp_decoder_ids'].to(self.device)).detach().clone()
        outputs_exp = self.expandor(
            inputs_embeds=outputs_summ_embeds,
            attention_mask=input['exp_attention_mask'].to(self.device),
            decoder_inputs_embeds=shift_embed_right(decoder_embed, 1, self.start_token_embed),
        )
        return outputs_exp.logits
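

# Construction sketch (assumptions: the "facebook/bart-base" checkpoint and the
# device choice are illustrative; any pair of BartForConditionalGeneration
# models sharing a vocabulary should fit the constructor):
#
#     from transformers import BartForConditionalGeneration
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     summarizor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
#     expandor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
#     autoencoder = BARTAutoEncoder(summarizor, expandor, device)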


class BartAutoEncoderDataset(Dataset):
    def __init__(self, article, tokenizer, shared, start_embed, device):
        # `shared` and `start_embed` are unused; the signature is kept as-is.
        input_len = len(article)
        inputs = tokenizer(article, return_tensors="pt", padding="longest")
        dl_input = dict()
        dl_input['summ_attention_mask'] = inputs['attention_mask'][:input_len].to(device)
        dl_input['summ_input_ids'] = inputs['input_ids'][:input_len].detach().clone().to(device)
        dl_input['exp_decoder_ids'] = inputs['input_ids'][:input_len].detach().clone().to(device)
        # The expander attends over the summarizer's outputs, which keep the
        # articles' padded length, so the same attention mask applies.
        dl_input['exp_attention_mask'] = inputs['attention_mask'][:input_len].to(device)
        # Reconstruction targets: the original article token ids.
        outputs = inputs['input_ids'][:input_len].to(device)
        self.data = dl_input
        self.targets = outputs

    def __getitem__(self, index):
        target_ids = self.targets[index]
        summ_input_ids = self.data['summ_input_ids'][index]
        summ_attention_mask = self.data['summ_attention_mask'][index]
        exp_attention_mask = self.data['exp_attention_mask'][index]
        exp_decoder_ids = self.data['exp_decoder_ids'][index]
        return {'summ_input_ids': summ_input_ids,
                'summ_attention_mask': summ_attention_mask,
                'exp_attention_mask': exp_attention_mask,
                'exp_decoder_ids': exp_decoder_ids,
                }, target_ids

    def __len__(self):
        # Number of examples (self.targets has one row per article).
        return self.targets.size(0)
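

# End-to-end training sketch (assumptions: `articles`, the tokenizer checkpoint,
# and all hyperparameters are illustrative, not part of this file; `device` and
# `autoencoder` are built as in the construction sketch above):
#
#     from torch.utils.data import DataLoader
#     from transformers import BartTokenizer
#
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#     articles = ["First document ...", "Second document ..."]
#     dataset = BartAutoEncoderDataset(articles, tokenizer, None, None, device)
#     loader = DataLoader(dataset, batch_size=2)
#     criterion = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=1e-5)
#     for batch, target_ids in loader:
#         logits = autoencoder(batch)  # (batch, seq_len, vocab_size)
#         loss = criterion(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()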