import re
import copy

import torch
import numpy as np
from tqdm import trange
def select_sentences(paragraph, num_sentences):
    '''Select the first num_sentences sentences (num_sentences > 0) or the
    last abs(num_sentences) sentences (num_sentences < 0) of a paragraph.'''
    sentences = re.split(r'(?<=[.!?])\s+', paragraph)
    if num_sentences < 0:
        selected_sentences = sentences[num_sentences:]   # last |num_sentences| sentences
    elif num_sentences > 0:
        selected_sentences = sentences[:num_sentences]   # first num_sentences sentences
    else:
        selected_sentences = []                          # num_sentences == 0: nothing selected
    return ' '.join(selected_sentences)
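
# Example (hypothetical input, not from the project's data):
#   p = "First sentence. Second one! A third? The last."
#   select_sentences(p, 2)   -> "First sentence. Second one!"
#   select_sentences(p, -1)  -> "The last."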

def getitem(dataset, index):
    '''Wrap the index-th tokenized example as a batch of size one.'''
    inputs = dict()
    inputs['input_ids'] = torch.LongTensor([dataset['input_ids'][index]])
    inputs['attention_mask'] = torch.LongTensor([dataset['attention_mask'][index]])
    return inputs

def reconstructionLoss(blocks, tokenizer, model, device):
    '''Summed token-level cross-entropy of the model reconstructing each block
    from itself. Only the first block's score is returned, since callers pass
    a single block at a time.'''
    scores = []
    model.eval()
    inputDataset = tokenizer(blocks)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for i in range(len(blocks)):
            inputs = getitem(inputDataset, i)
            # The model expects encoder and decoder inputs under these keys.
            dl_input = dict()
            dl_input['summ_input_ids'] = inputs['input_ids'].to(device)
            dl_input['summ_attention_mask'] = inputs['attention_mask'].to(device)
            dl_input['exp_decoder_ids'] = inputs['input_ids'].to(device)
            dl_input['exp_attention_mask'] = inputs['attention_mask'].to(device)
            labels = torch.flatten(inputs['input_ids']).to(device)
            outputs = model(dl_input)
            score = loss_fn(outputs.squeeze(), labels.squeeze())
            scores.append(score.item())
    return scores[0]

def paragraphLoss(paragraph1, paragraph2, tokenizer, model, device):
    '''Difference between the reconstruction loss of the two paragraphs scored
    separately and the loss of the merged paragraph; merge() greedily merges
    the boundary with the smallest value first.'''
    model.eval()
    splitScore1 = reconstructionLoss([paragraph1], tokenizer, model, device)
    splitScore2 = reconstructionLoss([paragraph2], tokenizer, model, device)
    splitScore = splitScore1 + splitScore2
    mergedParas = paragraph1 + '\n' + paragraph2
    mergedScore = reconstructionLoss([mergedParas], tokenizer, model, device)
    return splitScore - mergedScore
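
# --- Illustrative only -------------------------------------------------------
# The functions above expect (a) a tokenizer that maps a list of strings to
# {'input_ids': ..., 'attention_mask': ...} (flat lists for a single string,
# one list per string for a list, as Hugging Face tokenizers do) and (b) a
# model whose forward pass takes a dict with 'summ_input_ids',
# 'summ_attention_mask', 'exp_decoder_ids', 'exp_attention_mask' and returns
# per-token logits. ToyTokenizer and ToyModel below are hypothetical stand-ins
# that only mimic that interface; they are NOT the project's trained
# reconstruction model and exist only so the sketch at the bottom of the file
# can run end to end.
class ToyTokenizer:
    '''Whitespace "tokenizer" over a 100-id hash vocabulary.'''
    def _encode(self, text):
        return [hash(w) % 100 for w in text.split()]

    def __call__(self, texts):
        if isinstance(texts, str):
            ids = self._encode(texts)
            return {'input_ids': ids, 'attention_mask': [1] * len(ids)}
        ids = [self._encode(t) for t in texts]
        return {'input_ids': ids, 'attention_mask': [[1] * len(x) for x in ids]}


class ToyModel(torch.nn.Module):
    '''Returns random logits of shape (1, seq_len, vocab_size=100).'''
    def forward(self, batch):
        seq_len = batch['summ_input_ids'].shape[1]
        return torch.randn(1, seq_len, 100)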

class Document():
    def __init__(self, text, tokenizer,
                 segsoft='<block seg soft>', seghard='<block seg hard>'):
        '''
        text: dict mapping a section name to a list of paragraph strings
        tokenizer: tokenizer used to measure block lengths in tokens
        segsoft / seghard: marker strings for soft (mergeable) and hard
                           (non-mergeable) block boundaries
        '''
        self.text = text
        self.tokenizer = tokenizer
        self.getSegString(segsoft, seghard)
        self.segmentation = self.insertSeg(text)

    def gettext(self):
        return self.text

    def getSegString(self, segsoft, seghard):
        # Only accept the markers if they do not already occur in the document.
        if (segsoft not in self.text) and (seghard not in self.text):
            self.segStringSoft = segsoft
            self.segStringHard = seghard
        else:
            raise ValueError('Segment string already occurs in the document, provide unique segment strings!')

    def insertSeg(self, article):
        '''Assign a boundary marker after every paragraph: soft if the
        paragraph and the one that follows still fit into 1024 tokens together
        (so they may later be merged), hard otherwise or at the end of a
        section. Reference sections are skipped.'''
        ansText = []
        ansSeg = []
        ansKey = []
        tokenizer = self.tokenizer
        for key, content in article.items():
            if key in ['References', 'Reference']:
                continue
            for i in range(len(content)):
                paragraph = content[i]
                if i == len(content) - 1:
                    # The last paragraph of a section always ends with a hard boundary.
                    seg = self.segStringHard
                    ansText.append(paragraph)
                    ansSeg.append(seg)
                    ansKey.append(key)
                    break
                follow = content[i + 1]
                twoPara = paragraph + ' ' + follow
                if len(tokenizer(twoPara)['input_ids']) < 1024:
                    seg = self.segStringSoft
                else:
                    seg = self.segStringHard
                ansText.append(paragraph)
                ansSeg.append(seg)
                ansKey.append(key)
        ans = {'text': ansText, 'seg': ansSeg, 'key': ansKey}
        return ans

    def show(self):
        for i in range(len(self.segmentation['text'])):
            print(self.segmentation['key'][i])
            print(self.segmentation['text'][i])
            print(self.segmentation['seg'][i])
            print('\n')

    def updateReconstructionLoss(self, lossScore, index, model, device):
        '''After merging block index with block index+1, drop the consumed
        boundary score and recompute the scores of the two neighbouring
        boundaries (hard boundaries get an infinite score).'''
        model.eval()
        lossScore.pop(index)
        paragraph = self.segmentation['text'][index]
        if index > 0:
            if self.segmentation['seg'][index - 1] == self.segStringHard:
                lossScore[index - 1] = np.inf
            else:
                before = self.segmentation['text'][index - 1]
                lossScore[index - 1] = paragraphLoss(before, paragraph, self.tokenizer, model, device)
        if index < len(self.segmentation['text']) - 1:
            if self.segmentation['seg'][index] == self.segStringHard:
                lossScore[index] = np.inf
            else:
                follow = self.segmentation['text'][index + 1]
                lossScore[index] = paragraphLoss(paragraph, follow, self.tokenizer, model, device)
        return lossScore

    def merge(self, minPage, maxPage, model, device):
        '''Greedily merge adjacent blocks. Phase 1 merges until at most maxPage
        blocks remain; phase 2 continues merging down to minPage blocks and
        keeps the intermediate state with the lowest cumulative merge score.
        Returns the final number of blocks.'''
        model.eval()
        if minPage > len(self.segmentation['text']):
            return len(self.segmentation['text'])
        # Score every boundary; hard boundaries can never be merged.
        lossScore = []
        for i in trange(len(self.segmentation['text']) - 1):
            paragraph1 = self.segmentation['text'][i]
            paragraph2 = self.segmentation['text'][i + 1]
            if self.segmentation['seg'][i] == self.segStringHard:
                loss = np.inf
            else:
                loss = paragraphLoss(paragraph1, paragraph2, self.tokenizer, model, device)
            lossScore.append(loss)
        # Phase 1: merge the cheapest boundary until at most maxPage blocks remain.
        while len(self.segmentation['text']) > maxPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            index = lossScore.index(minScore)
            print('merging', index, 'and', index + 1)
            # update text
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index + 1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                [mergedParas] + \
                self.segmentation['text'][(index + 2):]
            # update key
            self.segmentation['key'].pop(index + 1)
            # update segments: boundaries next to an oversized merged block become hard
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index - 1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    self.segmentation['seg'][index - 1] = self.segStringHard
            if index < len(self.segmentation['text']) - 1:
                follow = self.segmentation['text'][index + 1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    self.segmentation['seg'][index] = self.segStringHard
            # update loss
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
        # Phase 2: keep merging down to minPage blocks while tracking the
        # cumulative merge score, and remember the best (lowest-score)
        # intermediate state. Deep copies are taken so the best snapshot is
        # not mutated by later merges.
        bestSegState = copy.deepcopy(self.segmentation)
        currentSegScore = 0
        miniSegScore = 0
        while len(self.segmentation['text']) > minPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            currentSegScore += minScore
            # update text
            index = lossScore.index(minScore)
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index + 1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                [mergedParas] + \
                self.segmentation['text'][(index + 2):]
            # update key
            self.segmentation['key'].pop(index + 1)
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index - 1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    print('warning: merged block exceeds 1024 tokens together with the previous block')
                    self.segmentation['seg'][index - 1] = self.segStringHard
            if index < len(self.segmentation['text']) - 1:
                follow = self.segmentation['text'][index + 1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    print('warning: merged block exceeds 1024 tokens together with the following block')
                    self.segmentation['seg'][index] = self.segStringHard
            # update score
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
            if currentSegScore <= miniSegScore:
                print('merging', index, 'and', index + 1)
                miniSegScore = currentSegScore
                bestSegState = copy.deepcopy(self.segmentation)
        # Restore the best segmentation found during phase 2.
        self.segmentation = bestSegState
        return len(self.segmentation['text'])
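

# Minimal end-to-end sketch, using the illustrative ToyTokenizer / ToyModel
# defined above in place of the real tokenizer and trained reconstruction
# model. The article dict below is made-up data that only shows the expected
# input format (section name -> list of paragraph strings).
if __name__ == '__main__':
    article = {
        'Introduction': ['First paragraph of the introduction.',
                         'Second paragraph, with a bit more detail.'],
        'Methods': ['A paragraph describing the method.',
                    'Another methods paragraph.'],
        'References': ['[1] A reference entry that insertSeg skips.'],
    }
    toy_tokenizer = ToyTokenizer()
    toy_model = ToyModel()
    doc = Document(article, toy_tokenizer)
    doc.show()
    num_pages = doc.merge(minPage=2, maxPage=3, model=toy_model, device='cpu')
    print('pages after merging:', num_pages)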