import copy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
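
# Illustrative behavior (values chosen for demonstration; pad id 1 and decoder
# start id 2 are assumptions matching facebook/bart-base, not fixed by this module):
#   shift_tokens_right(torch.tensor([[5, 6, 7]]), pad_token_id=1, decoder_start_token_id=2)
#   -> tensor([[2, 5, 6]])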


def shift_embed_right(input_embed: torch.Tensor, pad_token_id: int, decoder_start_token_embed: torch.Tensor):
    """
    Shift input embeddings one position to the right along the sequence dimension.
    """
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    shifted_input_embed = input_embed.new_zeros(input_embed.shape)
    shifted_input_embed[:, 1:, :] = input_embed[:, :-1, :].clone()
    shifted_input_embed[:, 0, :] = decoder_start_token_embed

    return shifted_input_embed


class BARTAutoEncoder(torch.nn.Module):
    """Summarize-then-expand autoencoder built from two BART seq2seq models."""

    def __init__(self, summarizor, expandor, device):
        super().__init__()
        self.summarizor = summarizor.to(device)
        self.expandor = expandor.to(device)
        self.config = summarizor.config
        # Embedding of BART's decoder start token (id 2), frozen so no gradients flow into it.
        self.start_token_embed = summarizor.model.shared(torch.tensor([2]).to(device)).detach().clone().requires_grad_(False)
        # Frozen copy of the shared token-embedding matrix, used to project
        # vocabulary logits back into embedding space.
        self.shared_weight = self.summarizor.model.shared.weight.detach().clone().requires_grad_(False)
        self.device = device

    def forward(self, input):
        # Run the summarizer on the original article tokens.
        summ_input = dict()
        summ_input['input_ids'] = input['summ_input_ids'].to(self.device)
        summ_input['attention_mask'] = input['summ_attention_mask'].to(self.device)
        outputs_summ = self.summarizor(**summ_input)

        # Project the summarizer's vocabulary logits through the shared embedding
        # matrix to obtain continuous ("soft") token embeddings for the expander.
        outputs_summ_logits = outputs_summ.logits @ self.shared_weight

        # Teacher-force the expander's decoder with embeddings of the original tokens,
        # shifted right and prefixed with the decoder start embedding
        # (pad_token_id 1 is BART's default).
        decoder_embed = self.expandor.model.shared(input['exp_decoder_ids']).detach().clone()
        outputs_exp = self.expandor(inputs_embeds=outputs_summ_logits,
                                    attention_mask=input['exp_attention_mask'],
                                    decoder_inputs_embeds=shift_embed_right(decoder_embed, 1, self.start_token_embed))
        return outputs_exp.logits
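
# A minimal construction sketch (commented out so the module stays import-safe);
# using facebook/bart-base for both halves is an assumption, not a requirement:
#
#   from transformers import BartForConditionalGeneration
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   summarizor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
#   expandor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
#   autoencoder = BARTAutoEncoder(summarizor, expandor, device)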


class BartAutoEncoderDataset(Dataset):
    """Wraps tokenized articles as (input dict, target ids) pairs for the autoencoder."""

    def __init__(self, article, tokenizer, shared, start_embed, device):
        # `shared` and `start_embed` are accepted but not used here.
        input_len = len(article)
        inputs = tokenizer(article,
                           return_tensors="pt",
                           padding="longest")

        dl_input = dict()
        dl_input['summ_attention_mask'] = inputs['attention_mask'][:input_len].to(device)
        dl_input['summ_input_ids'] = inputs['input_ids'][:input_len].detach().clone().to(device)
        dl_input['exp_decoder_ids'] = inputs['input_ids'][:input_len].detach().clone().to(device)
        # The expander's encoder consumes the summarizer outputs, which have the same
        # sequence length as the article tokens, so reuse the same attention mask
        # (slicing from `input_len:` would always yield an empty tensor here).
        dl_input['exp_attention_mask'] = inputs['attention_mask'][:input_len].to(device)
        outputs = inputs['input_ids'][:input_len].to(device)

        self.data = dl_input
        self.targets = outputs

    def __getitem__(self, index):
        target_ids = self.targets[index]
        summ_input_ids = self.data['summ_input_ids'][index]
        summ_attention_mask = self.data['summ_attention_mask'][index]
        exp_attention_mask = self.data['exp_attention_mask'][index]
        exp_decoder_ids = self.data['exp_decoder_ids'][index]
        return {'summ_input_ids': summ_input_ids,
                'summ_attention_mask': summ_attention_mask,
                'exp_attention_mask': exp_attention_mask,
                'exp_decoder_ids': exp_decoder_ids
                }, target_ids

    def __len__(self):
        # `self.data` is a dict, so len(self.data) would be the number of keys;
        # the dataset size is the number of target rows.
        return len(self.targets)
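

# A minimal end-to-end training sketch, assuming facebook/bart-base checkpoints for
# both halves, the matching BartTokenizer, and token-level cross-entropy against the
# original ids; the article list, batch size, and learning rate are placeholders.
if __name__ == "__main__":
    from transformers import BartForConditionalGeneration, BartTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    summarizor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    expandor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    model = BARTAutoEncoder(summarizor, expandor, device)

    articles = ["The quick brown fox jumps over the lazy dog."]  # placeholder data
    dataset = BartAutoEncoderDataset(articles, tokenizer,
                                     summarizor.model.shared,
                                     model.start_token_embed, device)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    model.train()
    for batch, target_ids in loader:
        logits = model(batch)  # (batch, target_len, vocab_size)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               target_ids.reshape(-1),
                               ignore_index=tokenizer.pad_token_id)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()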