import re

import numpy as np
import torch
from tqdm import trange


def select_sentences(paragraph, num_sentences):
    '''Return the first n sentences (n >= 0) or the last |n| sentences (n < 0) of a paragraph.'''
    sentences = re.split(r'(?<=[.!?])\s+', paragraph)
    if num_sentences < 0:
        selected_sentences = sentences[num_sentences:]
    else:
        # covers num_sentences == 0 as well (empty selection), which previously raised a NameError
        selected_sentences = sentences[:num_sentences]
    selected = ' '.join(selected_sentences)
    return selected
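

# Example behaviour of select_sentences, worked from the code above:
#   select_sentences('One. Two! Three?', 2)  -> 'One. Two!'   (first two sentences)
#   select_sentences('One. Two! Three?', -2) -> 'Two! Three?' (last two sentences)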


def getitem(dataset, index):
    inputs = dict()
    inputs['input_ids'] = torch.LongTensor([dataset['input_ids'][index]])
    inputs['attention_mask'] = torch.LongTensor([dataset['attention_mask'][index]])
    return inputs


def reconstructionLoss(blocks, tokenizer, model, device):
    '''Summed token-level reconstruction loss of the first block in `blocks`.'''
    scores = []
    model.eval()
    inputDataset = tokenizer(blocks)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for i in range(len(blocks)):
            inputs = getitem(inputDataset, i)
            dl_input = dict()
            dl_input['summ_input_ids'] = inputs['input_ids'].to(device)
            dl_input['summ_attention_mask'] = inputs['attention_mask'].to(device)
            dl_input['exp_decoder_ids'] = inputs['input_ids'].to(device)
            dl_input['exp_attention_mask'] = inputs['attention_mask'].to(device)
            labels = torch.flatten(inputs['input_ids']).to(device)
            outputs = model(dl_input)
            score = loss_fn(outputs.squeeze(), labels.squeeze())
            scores.append(score.item())
    # callers pass a single block, so only the first score is used
    return scores[0]


def paragraphLoss(paragraph1, paragraph2, tokenizer, model, device):
    model.eval()
    splitScore1 = reconstructionLoss([paragraph1], tokenizer, model, device)
    splitScore2 = reconstructionLoss([paragraph2], tokenizer, model, device)
    splitScore = splitScore1 + splitScore2
    mergedParas = paragraph1 + '\n' + paragraph2
    mergedScore = reconstructionLoss([mergedParas], tokenizer, model, device)
    return splitScore - mergedScore
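

# Merge criterion implemented by paragraphLoss:
#   score = [L(p1) + L(p2)] - L(p1 + '\n' + p2)
# where L(.) is the summed token-level reconstruction loss from reconstructionLoss.
# A low (or negative) score means the two paragraphs are reconstructed almost as
# well jointly as separately, so they are good candidates for merging into one block.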


class Document:
    def __init__(self, text, tokenizer,
                 segsoft='<block seg soft>', seghard='<block seg hard>'):
        '''
        text: dict mapping a section name to a list of paragraph strings
        tokenizer: tokenizer used to measure block lengths in tokens
        segsoft / seghard: marker strings for soft and hard block boundaries
        '''
        self.text = text
        self.tokenizer = tokenizer
        self.getSegString(segsoft, seghard)
        self.segmentation = self.insertSeg(text)

    def gettext(self):
        return self.text

    def getSegString(self, segsoft, seghard):
        # the boundary markers must not occur anywhere in the article text
        allParagraphs = [p for content in self.text.values() for p in content]
        if any(segsoft in p or seghard in p for p in allParagraphs):
            raise ValueError('Segment string invalid, provide unique segment strings!')
        self.segStringSoft = segsoft
        self.segStringHard = seghard

    def insertSeg(self, article):
        ansText = []
        ansSeg = []
        ansKey = []
        tokenizer = self.tokenizer
        for key, content in article.items():
            if key in ['References', 'Reference']:
                continue
            for i in range(len(content)):
                paragraph = content[i]
                if i == len(content) - 1:
                    # the last paragraph of a section always ends with a hard boundary
                    seg = self.segStringHard
                    ansText.append(paragraph)
                    ansSeg.append(seg)
                    ansKey.append(key)
                    break
                follow = content[i+1]
                twoPara = paragraph + ' ' + follow
                # soft boundary only if the pair still fits into the 1024-token model input
                if len(tokenizer(twoPara)['input_ids']) < 1024:
                    seg = self.segStringSoft
                else:
                    seg = self.segStringHard
                ansText.append(paragraph)
                ansSeg.append(seg)
                ansKey.append(key)
        ans = {'text': ansText, 'seg': ansSeg, 'key': ansKey}
        return ans

    def show(self):
        for i in range(len(self.segmentation['text'])):
            print(self.segmentation['key'][i])
            print(self.segmentation['text'][i])
            print(self.segmentation['seg'][i])
            print('\n')

    def updateReconstructionLoss(self, lossScore, index, model, device):
        '''Refresh the pairwise losses around the block that was just merged at `index`.'''
        model.eval()
        # the score of the merged pair itself disappears
        lossScore.pop(index)
        paragraph = self.segmentation['text'][index]
        if index > 0:
            if self.segmentation['seg'][index-1] == self.segStringHard:
                lossScore[index-1] = np.inf
            else:
                before = self.segmentation['text'][index-1]
                lossScore[index-1] = paragraphLoss(before, paragraph, self.tokenizer, model, device)
        if index < len(self.segmentation['text'])-1:
            if self.segmentation['seg'][index] == self.segStringHard:
                # fixed: previously wrote to lossScore[index-1], corrupting the wrong entry
                lossScore[index] = np.inf
            else:
                follow = self.segmentation['text'][index+1]
                lossScore[index] = paragraphLoss(paragraph, follow, self.tokenizer, model, device)
        return lossScore

    def merge(self, minPage, maxPage, model, device):
        model.eval()
        if minPage > len(self.segmentation['text']):
            return len(self.segmentation['text'])
        # pairwise merge cost for every adjacent pair of blocks
        lossScore = []
        for i in trange(len(self.segmentation['text']) - 1):
            paragraph1 = self.segmentation['text'][i]
            paragraph2 = self.segmentation['text'][i+1]
            if self.segmentation['seg'][i] == self.segStringHard:
                loss = np.inf
            else:
                loss = paragraphLoss(paragraph1, paragraph2, self.tokenizer, model, device)
            lossScore.append(loss)
        # first stage: merge greedily until the document fits into maxPage blocks
        while len(self.segmentation['text']) > maxPage and lossScore and min(lossScore) < np.inf:
            minScore = min(lossScore)
            index = lossScore.index(minScore)
            print('merging', index, 'and', index+1)
            # update text
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index+1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                                        [mergedParas] + \
                                        self.segmentation['text'][(index+2):]
            # update key
            self.segmentation['key'].pop(index+1)
            # update segments
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index-1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    self.segmentation['seg'][index-1] = self.segStringHard
            if index < len(self.segmentation['text'])-1:
                follow = self.segmentation['text'][index+1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    self.segmentation['seg'][index] = self.segStringHard
            # update loss
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
        # second stage: keep merging towards minPage, but only keep the segmentation
        # with the lowest cumulative merge cost seen so far (the original assigned an
        # alias of self.segmentation here, so the "best state" bookkeeping had no effect)
        currentSegScore = 0
        miniSegScore = 0
        bestSegState = {key: list(value) for key, value in self.segmentation.items()}
        while len(self.segmentation['text']) > minPage and lossScore and min(lossScore) < np.inf:
            minScore = min(lossScore)
            currentSegScore += minScore
            # update text
            index = lossScore.index(minScore)
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index+1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                                        [mergedParas] + \
                                        self.segmentation['text'][(index+2):]
            # update key
            self.segmentation['key'].pop(index+1)
            # update segments
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index-1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    print('warning: merged block and its predecessor exceed the 1024-token budget')
                    self.segmentation['seg'][index-1] = self.segStringHard
            if index < len(self.segmentation['text'])-1:
                follow = self.segmentation['text'][index+1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    print('warning: merged block and its successor exceed the 1024-token budget')
                    self.segmentation['seg'][index] = self.segStringHard
            # update score
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
            if currentSegScore <= miniSegScore:
                print('merging', index, 'and', index+1)
                miniSegScore = currentSegScore
                # snapshot the best segmentation seen so far
                bestSegState = {key: list(value) for key, value in self.segmentation.items()}
        self.segmentation = bestSegState
        return len(self.segmentation['text'])
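

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline).
# `BartTokenizer` and `DummyReconstructionModel` are stand-ins chosen for
# demonstration: the real reconstruction model is assumed to take the same
# dict input ('summ_input_ids', 'summ_attention_mask', 'exp_decoder_ids',
# 'exp_attention_mask') and return per-token logits, as reconstructionLoss expects.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import BartTokenizer

    class DummyReconstructionModel(torch.nn.Module):
        '''Toy stand-in that mimics the interface expected by reconstructionLoss.'''
        def __init__(self, vocab_size, hidden_size=32):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            self.proj = torch.nn.Linear(hidden_size, vocab_size)

        def forward(self, dl_input):
            hidden = self.embed(dl_input['exp_decoder_ids'])
            return self.proj(hidden)  # (batch, seq_len, vocab_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    model = DummyReconstructionModel(len(tokenizer)).to(device)

    article = {
        'Introduction': ['First paragraph of the introduction.',
                         'Second paragraph of the introduction.'],
        'Methods': ['A paragraph describing the method.'],
    }
    doc = Document(article, tokenizer)
    doc.show()
    numBlocks = doc.merge(minPage=1, maxPage=2, model=model, device=device)
    print('final number of blocks:', numBlocks)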