import copy
import re

import numpy as np
import torch
from tqdm import trange


def select_sentences(paragraph, num_sentences):
    """Return the first `num_sentences` sentences of `paragraph`
    (or the last `-num_sentences` sentences when negative)."""
    sentences = re.split(r'(?<=[.!?])\s+', paragraph)
    if num_sentences < 0:
        selected_sentences = sentences[num_sentences:]   # last -num_sentences sentences
    elif num_sentences > 0:
        selected_sentences = sentences[:num_sentences]   # first num_sentences sentences
    else:
        selected_sentences = []                          # num_sentences == 0 selects nothing
    return ' '.join(selected_sentences)


def getitem(dataset, index):
    """Wrap the `index`-th tokenized example as batched LongTensors."""
    inputs = dict()
    inputs['input_ids'] = torch.LongTensor([dataset['input_ids'][index]])
    inputs['attention_mask'] = torch.LongTensor([dataset['attention_mask'][index]])
    return inputs


def reconstructionLoss(blocks, tokenizer, model, device):
    """Summed token-level cross-entropy of the model reconstructing each block.

    Callers in this module always pass a single-element list, so only the
    first score is returned.
    """
    scores = []
    model.eval()
    inputDataset = tokenizer(blocks)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    for i in range(len(blocks)):
        inputs = getitem(inputDataset, i)
        dl_input = dict()
        dl_input['summ_input_ids'] = inputs['input_ids'].to(device)
        dl_input['summ_attention_mask'] = inputs['attention_mask'].to(device)
        dl_input['exp_decoder_ids'] = inputs['input_ids'].to(device)
        dl_input['exp_attention_mask'] = inputs['attention_mask'].to(device)
        labels = torch.flatten(inputs['input_ids']).to(device)
        outputs = model(dl_input)
        score = loss_fn(outputs.squeeze(), labels.squeeze())
        scores.append(score.item())
    return scores[0]


def paragraphLoss(paragraph1, paragraph2, tokenizer, model, device):
    """Difference between reconstructing the two paragraphs separately and
    reconstructing them merged into one block; the merge loop greedily picks
    the adjacent pair with the smallest value."""
    model.eval()
    splitScore1 = reconstructionLoss([paragraph1], tokenizer, model, device)
    splitScore2 = reconstructionLoss([paragraph2], tokenizer, model, device)
    splitScore = splitScore1 + splitScore2
    mergedParas = paragraph1 + '\n' + paragraph2
    mergedScore = reconstructionLoss([mergedParas], tokenizer, model, device)
    return splitScore - mergedScore


class Document:
    def __init__(self, text, tokenizer, segsoft='', seghard=''):
        '''
        text: dict mapping section names to lists of paragraph strings
        segsoft / seghard: marker strings for soft (mergeable) and hard
            (unmergeable) segment boundaries
        '''
        self.text = text
        self.tokenizer = tokenizer
        self.getSegString(segsoft, seghard)
        self.segmentation = self.insertSeg(text)

    def gettext(self):
        return self.text

    def getSegString(self, segsoft, seghard):
        # The markers must not occur inside the document itself, otherwise
        # boundaries would be ambiguous. (The original checked membership
        # against the dict, i.e. its keys; check the paragraphs instead.)
        collision = any(
            (segsoft and segsoft in paragraph) or (seghard and seghard in paragraph)
            for content in self.text.values() for paragraph in content
        )
        if collision:
            raise ValueError('Segment string invalid, provide unique segment strings!')
        self.segStringSoft = segsoft
        self.segStringHard = seghard

    def insertSeg(self, article):
        """Flatten the article into parallel lists of paragraphs, boundary
        markers, and section keys. A boundary is soft when the paragraph and
        its successor together fit the 1024-token model window, hard otherwise."""
        ansText = []
        ansSeg = []
        ansKey = []
        tokenizer = self.tokenizer
        for key, content in article.items():
            if key in ['References', 'Reference']:
                continue
            for i in range(len(content)):
                paragraph = content[i]
                if i == len(content) - 1:
                    # The last paragraph of a section always ends with a hard boundary.
                    ansText.append(paragraph)
                    ansSeg.append(self.segStringHard)
                    ansKey.append(key)
                    break
                follow = content[i + 1]
                twoPara = paragraph + ' ' + follow
                if len(tokenizer(twoPara)['input_ids']) < 1024:
                    seg = self.segStringSoft
                else:
                    seg = self.segStringHard
                ansText.append(paragraph)
                ansSeg.append(seg)
                ansKey.append(key)
        return {'text': ansText, 'seg': ansSeg, 'key': ansKey}

    def show(self):
        for i in range(len(self.segmentation['text'])):
            print(self.segmentation['key'][i])
            print(self.segmentation['text'][i])
            print(self.segmentation['seg'][i])
            print('\n')

    def updateReconstructionLoss(self, lossScore, index, model, device):
        """After merging pair (index, index + 1), drop its score and recompute
        the scores of the two neighbouring pairs."""
        model.eval()
        lossScore.pop(index)
        paragraph = self.segmentation['text'][index]
        if index > 0:
            if self.segmentation['seg'][index - 1] == self.segStringHard:
                lossScore[index - 1] = np.inf
            else:
                before = self.segmentation['text'][index - 1]
                lossScore[index - 1] = paragraphLoss(before, paragraph, self.tokenizer, model, device)
        if index < len(self.segmentation['text']) - 1:
            if self.segmentation['seg'][index] == self.segStringHard:
                lossScore[index] = np.inf   # fixed: the original wrote to index - 1 here
            else:
                follow = self.segmentation['text'][index + 1]
                lossScore[index] = paragraphLoss(paragraph, follow, self.tokenizer, model, device)
        return lossScore

    def merge(self, minPage, maxPage, model, device):
        """Greedily merge adjacent paragraphs until at most `maxPage` pages
        remain, then keep merging towards `minPage` while the cumulative
        paragraph loss stays at or below its best value. Returns the final
        number of pages."""
        model.eval()
        if minPage > len(self.segmentation['text']):
            return len(self.segmentation['text'])

        # Score every adjacent pair; hard boundaries can never be merged.
        lossScore = []
        for i in trange(len(self.segmentation['text']) - 1):
            paragraph1 = self.segmentation['text'][i]
            paragraph2 = self.segmentation['text'][i + 1]
            if self.segmentation['seg'][i] == self.segStringHard:
                loss = np.inf
            else:
                loss = paragraphLoss(paragraph1, paragraph2, self.tokenizer, model, device)
            lossScore.append(loss)

        # Phase 1: merge unconditionally until we are within maxPage pages.
        while len(self.segmentation['text']) > maxPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            index = lossScore.index(minScore)
            print('merging', index, 'and', index + 1)
            # update text
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index + 1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                [mergedParas] + \
                self.segmentation['text'][(index + 2):]
            # update key
            self.segmentation['key'].pop(index + 1)
            # update segments: the boundary inside the merged pair disappears,
            # and a neighbouring boundary hardens if the merged paragraph no
            # longer fits the 1024-token window together with that neighbour
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index - 1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    self.segmentation['seg'][index - 1] = self.segStringHard
            if index < len(self.segmentation['text']) - 1:
                follow = self.segmentation['text'][index + 1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    self.segmentation['seg'][index] = self.segStringHard
            # update loss
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)

        # Phase 2: keep merging towards minPage, tracking the cumulative score
        # and snapshotting the best (lowest-score) state seen so far. The
        # original aliased the working dict here, so its rollback never took
        # effect; deep copies make the snapshot real.
        bestSegState = copy.deepcopy(self.segmentation)
        currentSegScore = 0
        miniSegScore = 0
        while len(self.segmentation['text']) > minPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            currentSegScore += minScore
            # update text
            index = lossScore.index(minScore)
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index + 1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                [mergedParas] + \
                self.segmentation['text'][(index + 2):]
            # update key
            self.segmentation['key'].pop(index + 1)
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index - 1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    print('warning: merged paragraph exceeds the 1024-token window with its predecessor')
                    self.segmentation['seg'][index - 1] = self.segStringHard
            if index < len(self.segmentation['text']) - 1:
                follow = self.segmentation['text'][index + 1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    print('warning: merged paragraph exceeds the 1024-token window with its successor')
                    self.segmentation['seg'][index] = self.segStringHard
            # update score
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
            if currentSegScore <= miniSegScore:
                print('merging', index, 'and', index + 1)
                miniSegScore = currentSegScore
                bestSegState = copy.deepcopy(self.segmentation)
        self.segmentation = bestSegState
        return len(self.segmentation['text'])
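

# A minimal usage sketch (not part of the original module) showing how the
# pieces above fit together. `load_reconstruction_model` is a hypothetical
# helper standing in for however the project builds the model whose forward()
# accepts the dl_input dict assembled in reconstructionLoss(); the tokenizer
# is assumed to be a matching HuggingFace tokenizer with a 1024-token window.
# Kept commented out so the module still imports cleanly without the model.
#
# if __name__ == '__main__':
#     from transformers import AutoTokenizer
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')  # assumed checkpoint
#     model = load_reconstruction_model().to(device)                    # hypothetical helper
#
#     article = {
#         'Introduction': ['First paragraph ...', 'Second paragraph ...'],
#         'Methods': ['Third paragraph ...'],
#         'References': ['[1] Ignored.'],   # dropped by insertSeg()
#     }
#     doc = Document(article, tokenizer, segsoft='<SOFT>', seghard='<HARD>')
#     doc.show()
#     num_pages = doc.merge(minPage=2, maxPage=4, model=model, device=device)
#     print('final number of pages:', num_pages)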