File size: 4,086 Bytes
bd6566e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ad4088
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Build decoder inputs by shifting token ids one position to the right.

    The last column of ``input_ids`` is dropped, every other column moves one
    step to the right, and column 0 is filled with ``decoder_start_token_id``.
    After the shift, any ``-100`` entry is replaced by ``pad_token_id``.

    Args:
        input_ids: integer tensor indexed as ``[row, position]``.
        pad_token_id: id written wherever ``-100`` appears after the shift.
        decoder_start_token_id: id placed in the first column of every row.

    Returns:
        A new tensor with the same shape, dtype and device as ``input_ids``;
        ``input_ids`` itself is not modified.

    Raises:
        ValueError: if ``pad_token_id`` is None.
    """
    result = input_ids.new_zeros(input_ids.shape)
    result[:, 0] = decoder_start_token_id
    # Slice assignment copies into ``result``, so the caller's tensor is untouched.
    result[:, 1:] = input_ids[:, :-1]

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # -100 is used in labels to mark ignored positions; the decoder needs real ids there.
    return result.masked_fill_(result.eq(-100), pad_token_id)


def shift_embed_right(input_embed: torch.Tensor, pad_token_id: int, decoder_start_token_embed: torch.Tensor):
    """
    Shift a sequence of embeddings one position to the right.

    Position 0 of every row receives ``decoder_start_token_embed``. Every
    other position takes the vector previously one step to its left, and the
    final position of ``input_embed`` is dropped.

    Args:
        input_embed: tensor indexed as ``[row, position, feature]``.
        pad_token_id: not used by the shift; it is only checked for None.
            No padding replacement is done here.
        decoder_start_token_embed: value assigned to ``[:, 0, :]``; must be
            broadcastable to ``(rows, features)``.

    Returns:
        A new tensor with the same shape, dtype and device as ``input_embed``.

    Raises:
        ValueError: if ``pad_token_id`` is None.
    """
    shifted = input_embed.new_zeros(input_embed.shape)
    shifted[:, 0, :] = decoder_start_token_embed
    # Slice assignment copies the data (and stays differentiable w.r.t. input_embed).
    shifted[:, 1:, :] = input_embed[:, :-1, :]
    if pad_token_id is not None:
        return shifted
    raise ValueError("self.model.config.pad_token_id has to be defined.")


class BARTAutoEncoder(torch.nn.Module):
    """
    Chain a summarizer and an expander into a single autoencoder module.

    The summarizer's output logits are multiplied by a frozen copy of its
    shared embedding matrix. The resulting per-position vectors are fed to
    the expander as ``inputs_embeds``. The expander's decoder receives the
    embeddings of ``exp_decoder_ids`` shifted right by one position, with
    the start-token embedding in front.

    Args:
        summarizor: model exposing ``.config`` and ``.model.shared`` and
            accepting ``input_ids`` / ``attention_mask`` keywords (presumably
            a Hugging Face BART seq2seq model — TODO confirm).
        expandor: model exposing ``.model.shared`` and accepting
            ``inputs_embeds`` / ``attention_mask`` / ``decoder_inputs_embeds``.
        device: device the sub-models and every forward input are moved to.
        start_token_id: id whose embedding (taken from the summarizer's shared
            embedding) starts the expander's decoder input. Defaults to 2, the
            previously hard-coded value (presumably BART's decoder-start id —
            confirm against the tokenizer).
        pad_token_id: forwarded to ``shift_embed_right``. Defaults to 1, the
            previously hard-coded value.
    """

    def __init__(self, summarizor, expandor, device, start_token_id=2, pad_token_id=1):
        super().__init__()
        self.summarizor = summarizor.to(device)
        self.expandor = expandor.to(device)
        self.config = summarizor.config
        self.pad_token_id = pad_token_id
        # Both tensors below are detached snapshots taken at construction time.
        # They do not follow later updates to the summarizer's embeddings. As
        # plain attributes (not buffers), they are not moved by Module.to().
        start_ids = torch.tensor([start_token_id], device=device)
        self.start_token_embed = summarizor.model.shared(start_ids).detach().clone().requires_grad_(False)
        self.shared_weight = self.summarizor.model.shared.weight.detach().clone().requires_grad_(False)
        self.device = device

    def forward(self, input):
        """
        Run the summarizer, then the expander, and return the expander logits.

        Args:
            input: mapping with tensors under ``'summ_input_ids'``,
                ``'summ_attention_mask'``, ``'exp_attention_mask'`` and
                ``'exp_decoder_ids'``. Every one of them is moved to
                ``self.device`` before use.

        Returns:
            The ``logits`` attribute of the expander's output.
        """
        summ_input = {
            'input_ids': input['summ_input_ids'].to(self.device),
            'attention_mask': input['summ_attention_mask'].to(self.device),
        }
        outputs_summ = self.summarizor(**summ_input)
        # Weight the frozen embedding rows by the summarizer logits to get one
        # embedding vector per output position.
        # NOTE(review): the logits are used raw (no softmax), so this is not a
        # convex mix of embeddings — confirm this is intended.
        outputs_summ_embeds = outputs_summ.logits @ self.shared_weight

        # Expander inputs get the same device treatment as the summarizer inputs.
        exp_attention_mask = input['exp_attention_mask'].to(self.device)
        exp_decoder_ids = input['exp_decoder_ids'].to(self.device)
        # Teacher-forced decoder inputs; detached so no gradient reaches the
        # expander's embedding table through this path.
        decoder_embed = self.expandor.model.shared(exp_decoder_ids).detach().clone()
        outputs_exp = self.expandor(
            inputs_embeds=outputs_summ_embeds,
            attention_mask=exp_attention_mask,
            decoder_inputs_embeds=shift_embed_right(decoder_embed, self.pad_token_id, self.start_token_embed),
        )
        return outputs_exp.logits


class BartAutoEncoderDataset(Dataset):
    """
    In-memory dataset whose targets are the tokenized articles themselves.

    The articles are tokenized once, padded to the longest one. The same ids
    serve as the summarizer input, the expander decoder ids and the
    reconstruction target.

    Args:
        article: text(s) passed straight to ``tokenizer``.
        tokenizer: callable accepting ``(article, return_tensors="pt",
            padding="longest")`` and returning a mapping with ``'input_ids'``
            and ``'attention_mask'`` tensors indexed as ``[row, position]``.
        shared: unused; kept so existing call sites keep working.
        start_embed: unused; kept so existing call sites keep working.
        device: device every stored tensor is moved to.
    """

    def __init__(self, article, tokenizer, shared, start_embed, device):
        input_len = len(article)
        inputs = tokenizer(article,
                           return_tensors="pt",
                           padding="longest")
        input_ids = inputs['input_ids'][:input_len]
        attention_mask = inputs['attention_mask'][:input_len]

        self.data = {
            'summ_attention_mask': attention_mask.to(device),
            'summ_input_ids': input_ids.detach().clone().to(device),
            'exp_decoder_ids': input_ids.detach().clone().to(device),
            # Same rows as every other key. Slicing from ``input_len`` onward
            # (past the last row) yields an empty tensor that cannot be indexed.
            'exp_attention_mask': attention_mask.to(device),
        }
        self.targets = input_ids.to(device)

    def __getitem__(self, index):
        """
        Return ``(features, target_ids)`` for the article at ``index``.

        ``features`` holds the four model inputs for that row. ``target_ids``
        is the row's token ids.
        """
        features = {
            'summ_input_ids': self.data['summ_input_ids'][index],
            'summ_attention_mask': self.data['summ_attention_mask'][index],
            'exp_attention_mask': self.data['exp_attention_mask'][index],
            'exp_decoder_ids': self.data['exp_decoder_ids'][index],
        }
        return features, self.targets[index]

    def __len__(self):
        """Return the number of samples (rows), not the number of feature keys."""
        return len(self.targets)