import copy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """Shift token ids one position to the right along axis 1.

    Column 0 of every row is set to ``decoder_start_token_id`` and the last
    column of ``input_ids`` is dropped.  Any ``-100`` label placeholders that
    end up in the shifted result are replaced by ``pad_token_id``.

    Args:
        input_ids: 2-D integer tensor indexed as ``[row, position]``.
        pad_token_id: id written over ``-100`` entries; must not be ``None``.
        decoder_start_token_id: id placed in the first position of every row.

    Returns:
        A new tensor with the same shape and dtype as ``input_ids``.

    Raises:
        ValueError: if ``pad_token_id`` is ``None``.
    """
    # Validate before allocating/copying anything.
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    # new_zeros gives fresh storage, so a plain slice copy is enough.
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    # Replace possible -100 values in labels by `pad_token_id`.
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids


def shift_embed_right(input_embed: torch.Tensor, pad_token_id: int, decoder_start_token_embed: torch.Tensor):
    """Shift embeddings one position to the right along axis 1.

    Position 0 of every row is set to ``decoder_start_token_embed`` (which
    must broadcast to ``input_embed[:, 0, :]``) and the last position of
    ``input_embed`` is dropped.

    Args:
        input_embed: 3-D tensor indexed as ``[row, position, feature]``.
        pad_token_id: only checked for ``None``; otherwise unused.  Kept for
            signature parity with ``shift_tokens_right``.
        decoder_start_token_embed: vector(s) written into position 0.

    Returns:
        A new tensor with the same shape and dtype as ``input_embed``.

    Raises:
        ValueError: if ``pad_token_id`` is ``None``.
    """
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    shifted_input_embed = input_embed.new_zeros(input_embed.shape)
    shifted_input_embed[:, 1:, :] = input_embed[:, :-1, :]
    shifted_input_embed[:, 0, :] = decoder_start_token_embed
    return shifted_input_embed


class BARTAutoEncoder(torch.nn.Module):
    """Chain a summarizer model into an expander model.

    The summarizer's output logits are projected through a frozen copy of the
    summarizer's shared embedding matrix.  The resulting continuous vectors are
    fed to the expander as ``inputs_embeds``.  The expander's decoder receives
    the embedded ``exp_decoder_ids`` shifted one position right, with a frozen
    start-token embedding in position 0.
    """

    # Token id embedded to form the decoder start vector.
    # NOTE(review): presumably the model's decoder_start_token_id; confirm
    # against config.decoder_start_token_id.
    DECODER_START_TOKEN_ID = 2
    # Passed to shift_embed_right, which only checks that it is not None.
    PAD_TOKEN_ID = 1

    def __init__(self, summarizor, expandor, device):
        """Move both sub-models to ``device`` and snapshot frozen embeddings.

        Args:
            summarizor: model exposing ``.config``, ``.model.shared`` (an
                embedding module) and returning an object with ``.logits``.
            expandor: model exposing ``.model.shared`` and accepting
                ``inputs_embeds`` / ``attention_mask`` /
                ``decoder_inputs_embeds`` keyword arguments.
            device: device every input is moved to in ``forward``.
        """
        super().__init__()
        self.summarizor = summarizor.to(device)
        self.expandor = expandor.to(device)
        self.config = summarizor.config
        # Both tensors below are detached copies.  No gradient flows into
        # them, and they do not track later updates to the summarizer's
        # embedding table.
        start_ids = torch.tensor([self.DECODER_START_TOKEN_ID]).to(device)
        self.start_token_embed = (
            summarizor.model.shared(start_ids).detach().clone().requires_grad_(False)
        )
        self.shared_weight = (
            self.summarizor.model.shared.weight.detach().clone().requires_grad_(False)
        )
        self.device = device

    def forward(self, input):
        """Run the summarizer, then the expander; return the expander's logits.

        Args:
            input: mapping with tensor entries ``summ_input_ids``,
                ``summ_attention_mask``, ``exp_attention_mask`` and
                ``exp_decoder_ids`` (the keys produced by
                ``BartAutoEncoderDataset``).  All are moved to ``self.device``.

        Returns:
            ``self.expandor(...).logits``.
        """
        summ_input = {
            'input_ids': input['summ_input_ids'].to(self.device),
            'attention_mask': input['summ_attention_mask'].to(self.device),
        }
        outputs_summ = self.summarizor(**summ_input)
        # Project logits onto embedding space: assumes the logits' last axis
        # matches the rows (vocabulary) of shared.weight.
        # NOTE(review): raw logits are projected, not softmax probabilities.
        # If a probability-weighted embedding mix was intended, apply
        # F.softmax(logits, dim=-1) first -- confirm intent.
        outputs_summ_logits = outputs_summ.logits @ self.shared_weight
        exp_decoder_ids = input['exp_decoder_ids'].to(self.device)
        # Detached: no gradient reaches the expander's embedding table through
        # the teacher-forced decoder inputs.
        decoder_embed = self.expandor.model.shared(exp_decoder_ids).detach().clone()
        outputs_exp = self.expandor(
            inputs_embeds=outputs_summ_logits,
            attention_mask=input['exp_attention_mask'].to(self.device),
            decoder_inputs_embeds=shift_embed_right(
                decoder_embed, self.PAD_TOKEN_ID, self.start_token_embed
            ),
        )
        return outputs_exp.logits


class BartAutoEncoderDataset(Dataset):
    """Map-style dataset yielding ``(inputs_dict, target_ids)`` pairs.

    All articles are tokenized together (padded to the longest).  The same
    token ids serve as summarizer input, expander decoder ids and target.
    """

    _ITEM_KEYS = ('summ_input_ids', 'summ_attention_mask', 'exp_attention_mask', 'exp_decoder_ids')

    def __init__(self, article, tokenizer, shared, start_embed, device):
        """Tokenize ``article`` and move every stored tensor to ``device``.

        Args:
            article: sequence of texts accepted by ``tokenizer``.
            tokenizer: callable invoked as ``tokenizer(article,
                return_tensors="pt", padding="longest")`` returning a mapping
                with ``input_ids`` and ``attention_mask`` tensors.
            shared: unused; kept for interface compatibility.
            start_embed: unused; kept for interface compatibility.
            device: device all stored tensors are moved to.
        """
        input_len = len(article)
        inputs = tokenizer(article, return_tensors="pt", padding="longest")
        input_ids = inputs['input_ids'][:input_len]
        attention_mask = inputs['attention_mask'][:input_len]
        self.data = {
            'summ_attention_mask': attention_mask.to(device),
            'summ_input_ids': input_ids.detach().clone().to(device),
            'exp_decoder_ids': input_ids.detach().clone().to(device),
            # The tokenizer only saw `article`, so rows past input_len do not
            # exist; reuse the input mask for the expander's encoder.  This
            # assumes summarizer output positions align with input positions
            # (TODO confirm).
            'exp_attention_mask': attention_mask.to(device),
        }
        self.targets = input_ids.to(device)

    def __getitem__(self, index):
        """Return ``(inputs, target_ids)`` for sample ``index``."""
        sample = {key: self.data[key][index] for key in self._ITEM_KEYS}
        return sample, self.targets[index]

    def __len__(self):
        """Return the number of samples (rows), not the number of fields."""
        return len(self.targets)