com3dian committed on
Commit bd6566e · verified · 1 Parent(s): 339798f

Create BartSE.py

Files changed (1)
  1. BartSE.py +93 -0
BartSE.py ADDED
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


def shift_embed_right(input_embed: torch.Tensor, pad_token_id: int, decoder_start_token_embed: torch.Tensor):
    """
    Shift input embeddings one position to the right and place the decoder
    start token embedding in the first position.
    """
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    shifted_input_embed = input_embed.new_zeros(input_embed.shape)
    shifted_input_embed[:, 1:, :] = input_embed[:, :-1, :].clone()
    shifted_input_embed[:, 0, :] = decoder_start_token_embed

    return shifted_input_embed


class BARTAutoEncoder(torch.nn.Module):
    """
    Autoencoder built from two BART models: a `summarizor` that compresses the
    input and an `expandor` that reconstructs it from the summarizer's output
    distribution, passed along as "soft" token embeddings.
    """

    def __init__(self, summarizor, expandor, device):
        super().__init__()
        self.summarizor = summarizor.to(device)
        self.expandor = expandor.to(device)
        self.config = summarizor.config
        # BART's decoder start token id is 2 (</s>); its embedding is kept frozen.
        self.start_token_embed = summarizor.model.shared(
            torch.tensor([2]).to(device)).detach().clone().requires_grad_(False)
        # Frozen copy of the shared embedding matrix, used to project the
        # summarizer's vocabulary logits back into embedding space.
        self.shared_weight = self.summarizor.model.shared.weight.detach().clone().requires_grad_(False)
        self.device = device

    def forward(self, input):
        summ_input = dict()
        summ_input['input_ids'] = input['summ_input_ids'].to(self.device)
        summ_input['attention_mask'] = input['summ_attention_mask'].to(self.device)

        # Summarize, then map the logits onto the embedding matrix to obtain
        # soft token embeddings: (batch, seq, vocab) @ (vocab, d_model).
        outputs_summ = self.summarizor(**summ_input)
        outputs_summ_embeds = outputs_summ.logits @ self.shared_weight

        # Teacher forcing for the expander: embed the target ids and shift them
        # one position to the right, starting from the decoder start token
        # embedding (BART's pad token id is 1).
        decoder_embed = self.expandor.model.shared(input['exp_decoder_ids']).detach().clone()
        outputs_exp = self.expandor(
            inputs_embeds=outputs_summ_embeds,
            attention_mask=input['exp_attention_mask'],
            decoder_inputs_embeds=shift_embed_right(decoder_embed, 1, self.start_token_embed),
        )
        return outputs_exp.logits


class BartAutoEncoderDataset(Dataset):
    """
    Tokenizes a list of articles and exposes the tensors expected by
    `BARTAutoEncoder.forward`, with the original input ids as targets.
    (`shared` and `start_embed` are accepted for interface compatibility but unused.)
    """

    def __init__(self, article, tokenizer, shared, start_embed, device):
        input_len = len(article)
        inputs = tokenizer(article,
                           return_tensors="pt",
                           padding="longest")
        dl_input = dict()
        dl_input['summ_attention_mask'] = inputs['attention_mask'][:input_len].to(device)
        dl_input['summ_input_ids'] = inputs['input_ids'][:input_len].detach().clone().to(device)

        dl_input['exp_decoder_ids'] = inputs['input_ids'][:input_len].detach().clone().to(device)
        dl_input['exp_attention_mask'] = inputs['attention_mask'][:input_len].to(device)

        outputs = inputs['input_ids'][:input_len].to(device)
        self.data = dl_input
        self.targets = outputs

    def __getitem__(self, index):
        target_ids = self.targets[index]
        summ_input_ids = self.data['summ_input_ids'][index]
        summ_attention_mask = self.data['summ_attention_mask'][index]
        exp_attention_mask = self.data['exp_attention_mask'][index]
        exp_decoder_ids = self.data['exp_decoder_ids'][index]

        return {'summ_input_ids': summ_input_ids,
                'summ_attention_mask': summ_attention_mask,
                'exp_attention_mask': exp_attention_mask,
                'exp_decoder_ids': exp_decoder_ids
                }, target_ids

    def __len__(self):
        return self.targets.shape[0]
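
A minimal usage sketch, not part of the commit: it assumes facebook/bart-base checkpoints for both the summarizer and the expander, a toy article list, and a plain cross-entropy reconstruction loss against the original input ids; the hyperparameters and names below are illustrative only.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BartTokenizer, BartForConditionalGeneration

# Assumes BARTAutoEncoder and BartAutoEncoderDataset from BartSE.py are importable.
from BartSE import BARTAutoEncoder, BartAutoEncoderDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
summarizor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
expandor = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model = BARTAutoEncoder(summarizor, expandor, device)

articles = ["The quick brown fox jumps over the lazy dog.",
            "BART is a denoising sequence-to-sequence pretraining method."]
dataset = BartAutoEncoderDataset(articles, tokenizer,
                                 shared=model.shared_weight,
                                 start_embed=model.start_token_embed,
                                 device=device)
loader = DataLoader(dataset, batch_size=2)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for batch, target_ids in loader:
    logits = model(batch)  # (batch, seq_len, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           target_ids.reshape(-1),
                           ignore_index=model.config.pad_token_id)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()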