com3dian committed on
Commit
2b3c5ca
·
verified ·
1 Parent(s): 7ad4088

Create document.py

Browse files
Files changed (1) hide show
  1. document.py +210 -0
document.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import numpy as np
4
+ from tqdm import trange
5
+
6
+
7
def select_sentences(paragraph, num_sentences):
    """Return the leading or trailing sentences of ``paragraph`` as one string.

    The paragraph is split wherever whitespace follows '.', '!' or '?'.

    paragraph: string to split.
    num_sentences: int. A positive value keeps that many sentences from the
        start, a negative value keeps that many from the end, and zero
        selects nothing.

    Returns the selected sentences joined by single spaces ('' when
    ``num_sentences`` is 0).
    """
    sentences = re.split(r'(?<=[.!?])\s+', paragraph)
    if num_sentences < 0:
        chosen = sentences[num_sentences:]
    elif num_sentences > 0:
        chosen = sentences[:num_sentences]
    else:
        # Previously this case left the selection unbound and raised
        # UnboundLocalError; asking for zero sentences yields an empty string.
        chosen = []
    return ' '.join(chosen)
15
+
16
def getitem(dataset, index):
    """Build a batch-of-one model input from row ``index`` of ``dataset``.

    dataset: mapping with 'input_ids' and 'attention_mask' entries, each
        indexable by ``index`` (presumably the un-tensorised output of a
        tokenizer call -- confirm against callers).
    index: position of the example to extract.

    Returns a dict with 'input_ids' and 'attention_mask', each a
    ``torch.long`` tensor whose leading dimension is 1.
    """
    # torch.tensor(..., dtype=torch.long) is the recommended replacement for
    # the legacy torch.LongTensor(...) constructor.
    return {
        'input_ids': torch.tensor([dataset['input_ids'][index]],
                                  dtype=torch.long),
        'attention_mask': torch.tensor([dataset['attention_mask'][index]],
                                       dtype=torch.long),
    }
22
+
23
def reconstructionLoss(blocks, tokenizer, model, device):
    """Score how well ``model`` reconstructs the token ids of the first block.

    Every block is tokenized and fed to ``model`` as both the
    'summ_*' inputs and the 'exp_*' inputs; the summed cross-entropy between
    the model output and the block's own token ids is its score.

    blocks: non-empty list of strings.
    tokenizer: callable applied to ``blocks``; its result must expose
        'input_ids' and 'attention_mask' indexable per block.
    model: callable taking the input dict built below; it must also provide
        ``eval()``. Its output is assumed to squeeze to
        (sequence_length, vocab_size) logits -- TODO confirm.
    device: device the input tensors are moved to.

    Returns the score of ``blocks[0]`` as a Python float. Scores of any
    further blocks are computed but not returned.

    Raises ValueError if ``blocks`` is empty.
    """
    if not blocks:
        # Fail early with a clear message instead of an IndexError on
        # scores[0] after tokenizing nothing.
        raise ValueError('blocks must contain at least one text block')

    model.eval()
    inputDataset = tokenizer(blocks)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    scores = []
    # Pure evaluation: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        for i in range(len(blocks)):
            inputs = getitem(inputDataset, i)
            # Move each tensor to the device once and reuse it.
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)
            dl_input = {
                'summ_input_ids': input_ids,
                'summ_attention_mask': attention_mask,
                'exp_decoder_ids': input_ids,
                'exp_attention_mask': attention_mask,
            }

            # The targets are the input token ids themselves.
            labels = torch.flatten(input_ids)
            outputs = model(dl_input)
            score = loss_fn(outputs.squeeze(), labels.squeeze())
            scores.append(score.item())
    return scores[0]
41
+
42
def paragraphLoss(paragraph1, paragraph2, tokenizer, model, device):
    """Return the reconstruction-loss difference between split and merged text.

    The value is the sum of the two paragraphs' individual reconstruction
    losses minus the reconstruction loss of the two paragraphs joined by a
    newline.
    """
    model.eval()
    # Loss of each paragraph scored on its own, in order.
    separateLoss = sum(
        reconstructionLoss([para], tokenizer, model, device)
        for para in (paragraph1, paragraph2)
    )
    # Loss of the two paragraphs scored as a single block.
    joined = '\n'.join((paragraph1, paragraph2))
    jointLoss = reconstructionLoss([joined], tokenizer, model, device)
    return separateLoss - jointLoss
50
+
51
class Document():
    """An article split into paragraphs that can be greedily merged into pages.

    ``segmentation`` holds three parallel lists:
      'text' -- the paragraphs (or merged paragraph groups),
      'seg'  -- for each entry, the marker describing the boundary that
                follows it (soft = may be merged with the next entry,
                hard = must not be merged),
      'key'  -- the section key each entry came from.
    """

    # Token threshold used when deciding whether two adjacent entries may
    # still be merged.
    MAX_TOKENS = 1024

    def __init__(self, text, tokenizer,
                 segsoft='<block seg soft>', seghard='<block seg hard>'):
        '''
        text: mapping of section key -> list of paragraph strings
              (``insertSeg`` iterates it with ``.items()``)
        tokenizer: callable returning a mapping with an 'input_ids' entry
        segsoft, seghard: marker strings for soft / hard boundaries
        '''
        self.text = text
        self.tokenizer = tokenizer
        self.getSegString(segsoft, seghard)
        self.segmentation = self.insertSeg(text)

    def gettext(self):
        """Return the text object passed to the constructor."""
        return self.text

    def getSegString(self, segsoft, seghard):
        """Store the boundary markers, rejecting ones contained in ``self.text``.

        The check is a plain ``in`` test on ``self.text``; for a dict this
        only compares against the section keys.

        Returns 0 on success; raises ValueError when a marker collides.
        """
        if (segsoft in self.text) or (seghard in self.text):
            raise ValueError('Segment string invalid, provide unique segment strings!')
        self.segStringSoft = segsoft
        self.segStringHard = seghard
        return 0

    def _tokenCount(self, text):
        """Return the number of token ids the tokenizer produces for ``text``."""
        return len(self.tokenizer(text)['input_ids'])

    def insertSeg(self, article):
        """Build the initial segmentation from ``article``.

        Sections keyed 'References' or 'Reference' are skipped. The last
        paragraph of every section gets a hard boundary; any other paragraph
        gets a soft boundary when it and its successor (joined by a space)
        tokenize to fewer than MAX_TOKENS ids, otherwise a hard one.

        Returns a dict with the parallel lists 'text', 'seg' and 'key'.
        """
        ansText = []
        ansSeg = []
        ansKey = []
        for key, content in article.items():
            if key in ['References', 'Reference']:
                continue
            lastIndex = len(content) - 1
            for i, paragraph in enumerate(content):
                if i == lastIndex:
                    # Never merge across a section end.
                    seg = self.segStringHard
                elif self._tokenCount(paragraph + ' ' + content[i + 1]) < self.MAX_TOKENS:
                    seg = self.segStringSoft
                else:
                    seg = self.segStringHard
                ansText.append(paragraph)
                ansSeg.append(seg)
                ansKey.append(key)
        return {'text': ansText, 'seg': ansSeg, 'key': ansKey}

    def show(self):
        """Print key, text and boundary marker of every entry."""
        for i in range(len(self.segmentation['text'])):
            print(self.segmentation['key'][i])
            print(self.segmentation['text'][i])
            print(self.segmentation['seg'][i])
            print('\n')

    def updateReconstrcutionLoss(self, lossScore, index, model, device):
        """Refresh ``lossScore`` after entries ``index`` and ``index+1`` merged.

        Must be called after ``self.segmentation`` has been updated for the
        merge. ``lossScore[i]`` is the merge score of entries i and i+1; the
        score of the consumed boundary is removed and the scores of the two
        boundaries around the merged entry are recomputed (``np.inf`` for a
        hard boundary).

        The list is modified in place and also returned.
        """
        model.eval()
        lossScore.pop(index)
        paragraph = self.segmentation['text'][index]
        if index > 0:
            # Boundary between the previous entry and the merged entry.
            if self.segmentation['seg'][index-1] == self.segStringHard:
                lossScore[index-1] = np.inf
            else:
                before = self.segmentation['text'][index-1]
                lossScore[index-1] = paragraphLoss(before, paragraph, self.tokenizer, model, device)
        if index < len(self.segmentation['text'])-1:
            # Boundary between the merged entry and the following entry.
            if self.segmentation['seg'][index] == self.segStringHard:
                # Fixed: this used to write to lossScore[index-1], which
                # left a stale score here and overwrote a different
                # boundary (the last one when index == 0).
                lossScore[index] = np.inf
            else:
                follow = self.segmentation['text'][index+1]
                lossScore[index] = paragraphLoss(paragraph, follow, self.tokenizer, model, device)

        return lossScore

    def _mergeAt(self, index, warn=False):
        """Merge entries ``index`` and ``index+1`` of ``self.segmentation``.

        The texts are joined with a newline, the second entry's key and the
        consumed boundary marker are dropped, and a neighbouring boundary is
        turned hard when merging across it would exceed MAX_TOKENS ids.
        With ``warn`` set, 'warning' is printed whenever that happens.
        """
        state = self.segmentation
        texts = state['text']
        paragraph = texts[index] + '\n' + texts[index+1]
        state['text'] = texts[:index] + [paragraph] + texts[(index+2):]
        state['key'].pop(index+1)
        # The merged entry keeps the boundary marker of its second half.
        state['seg'].pop(index)

        if index > 0:
            before = state['text'][index-1]
            if self._tokenCount(before + '\n' + paragraph) > self.MAX_TOKENS:
                if warn:
                    print('warning')
                state['seg'][index-1] = self.segStringHard

        if index < len(state['text'])-1:
            follow = state['text'][index+1]
            if self._tokenCount(paragraph + '\n' + follow) > self.MAX_TOKENS:
                if warn:
                    print('warning')
                state['seg'][index] = self.segStringHard

    @staticmethod
    def _hasMergeable(lossScore):
        """Return True when at least one boundary has a finite merge score."""
        # The emptiness check avoids min() raising ValueError on an empty list.
        return bool(lossScore) and min(lossScore) < np.inf

    def merge(self, minPage, maxPage, model, device):
        """Greedily merge adjacent entries, lowest merge score first.

        Phase 1 merges until at most ``maxPage`` entries remain; phase 2
        keeps merging until ``minPage`` entries remain. Both phases stop
        early when no finite-score boundary is left. Nothing is done when
        ``minPage`` exceeds the current number of entries.

        Returns the number of entries after merging.
        """
        model.eval()
        if minPage > len(self.segmentation['text']):
            return len(self.segmentation['text'])

        # Initial merge score for every boundary.
        lossScore = []
        for i in trange(len(self.segmentation['text']) - 1):
            paragraph1 = self.segmentation['text'][i]
            paragraph2 = self.segmentation['text'][i+1]
            if self.segmentation['seg'][i] == self.segStringHard:
                loss = np.inf
            else:
                loss = paragraphLoss(paragraph1, paragraph2, self.tokenizer, model, device)
            lossScore.append(loss)

        # Phase 1: forced merges down to maxPage.
        while len(self.segmentation['text']) > maxPage and self._hasMergeable(lossScore):
            index = lossScore.index(min(lossScore))
            print('merging', index, 'and', index+1)
            self._mergeAt(index)
            lossScore = self.updateReconstrcutionLoss(lossScore, index, model, device)

        # Phase 2: continue down to minPage while tracking the running score.
        # NOTE(review): the original code worked on `currentSegState`, an
        # alias of self.segmentation, so every merge here is applied
        # regardless of the score comparison below. The comparison looks
        # like it was meant to commit only improving states -- confirm the
        # intended behaviour before changing it. Behaviour is kept as-is.
        currentSegScore = 0
        miniSegScore = 0
        while len(self.segmentation['text']) > minPage and self._hasMergeable(lossScore):
            minScore = min(lossScore)
            currentSegScore += minScore
            index = lossScore.index(minScore)
            self._mergeAt(index, warn=True)
            lossScore = self.updateReconstrcutionLoss(lossScore, index, model, device)
            if currentSegScore <= miniSegScore:
                print('merging', index, 'and', index+1)
                miniSegScore = currentSegScore
        return len(self.segmentation['text'])