# doc-to-slides / document.py

import re
import copy
import torch
import numpy as np
from tqdm import trange

def select_sentences(paragraph, num_sentences):
    """Return the first (num_sentences > 0) or last (num_sentences < 0)
    sentences of a paragraph, joined back into a single string."""
    sentences = re.split(r'(?<=[.!?])\s+', paragraph)
    if num_sentences < 0:
        selected_sentences = sentences[num_sentences:]
    elif num_sentences > 0:
        selected_sentences = sentences[:num_sentences]
    else:
        # num_sentences == 0 previously fell through and raised a NameError
        selected_sentences = []
    return ' '.join(selected_sentences)

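# Example (illustrative values): negative counts take sentences from the end,
# positive counts from the start.
#   select_sentences('One. Two! Three?', -2)  ->  'Two! Three?'
#   select_sentences('One. Two! Three?', 2)   ->  'One. Two!'
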
def getitem(dataset, index):
    inputs = dict()
    inputs['input_ids'] = torch.LongTensor([dataset['input_ids'][index]])
    inputs['attention_mask'] = torch.LongTensor([dataset['attention_mask'][index]])
    return inputs

def reconstructionLoss(blocks, tokenizer, model, device):
    """Summed token-level cross-entropy of the model reconstructing the
    first block in `blocks`. Callers in this module pass a single block."""
    scores = []
    model.eval()
    inputDataset = tokenizer(blocks)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for i in range(len(blocks)):
            inputs = getitem(inputDataset, i)
            dl_input = dict()
            dl_input['summ_input_ids'] = inputs['input_ids'].to(device)
            dl_input['summ_attention_mask'] = inputs['attention_mask'].to(device)
            dl_input['exp_decoder_ids'] = inputs['input_ids'].to(device)
            dl_input['exp_attention_mask'] = inputs['attention_mask'].to(device)
            labels = torch.flatten(inputs['input_ids']).to(device)
            outputs = model(dl_input)
            score = loss_fn(outputs.squeeze(), labels.squeeze())
            scores.append(score.item())
    return scores[0]

def paragraphLoss(paragraph1, paragraph2, tokenizer, model, device):
    """Reconstruction loss of the two paragraphs kept apart minus the loss
    of their concatenation; merge() merges the lowest-scoring pair first."""
    model.eval()
    splitScore1 = reconstructionLoss([paragraph1], tokenizer, model, device)
    splitScore2 = reconstructionLoss([paragraph2], tokenizer, model, device)
    splitScore = splitScore1 + splitScore2
    mergedParas = paragraph1 + '\n' + paragraph2
    mergedScore = reconstructionLoss([mergedParas], tokenizer, model, device)
    return splitScore - mergedScore

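# Sketch of a paragraphLoss call. The `model` is assumed to be the project's
# own wrapper: it must accept the dl_input dict built in reconstructionLoss
# and return token logits. The names below are illustrative, not part of
# this module.
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   score = paragraphLoss('First paragraph.', 'A second paragraph.',
#                         tokenizer, model, device)
#   # merge() treats the lowest-scoring adjacent pair as the best candidate.
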
class Document():
    def __init__(self, text, tokenizer,
                 segsoft='<block seg soft>', seghard='<block seg hard>'):
        '''
        text: dict mapping section titles to lists of paragraph strings
        segsoft, seghard: marker strings labelling soft/hard block boundaries
        '''
        self.text = text
        self.tokenizer = tokenizer
        self.getSegString(segsoft, seghard)
        self.segmentation = self.insertSeg(text)

    def gettext(self):
        return self.text

    def getSegString(self, segsoft, seghard):
        if (segsoft not in self.text) and (seghard not in self.text):
            self.segStringSoft = segsoft
            self.segStringHard = seghard
        else:
            raise ValueError('Segment string invalid; provide unique segment strings!')

    def insertSeg(self, article):
        '''Label every paragraph with a soft or hard segment boundary;
        a pair that would exceed 1024 tokens when joined gets a hard one.'''
        ansText = []
        ansSeg = []
        ansKey = []
        tokenizer = self.tokenizer
        for key, content in article.items():
            if key in ['References', 'Reference']:
                continue
            for i in range(len(content)):
                paragraph = content[i]
                if i == len(content) - 1:
                    # the last paragraph of a section always ends a block
                    ansText.append(paragraph)
                    ansSeg.append(self.segStringHard)
                    ansKey.append(key)
                    break
                follow = content[i + 1]
                twoPara = paragraph + ' ' + follow
                if len(tokenizer(twoPara)['input_ids']) < 1024:
                    seg = self.segStringSoft
                else:
                    seg = self.segStringHard
                ansText.append(paragraph)
                ansSeg.append(seg)
                ansKey.append(key)
        return {'text': ansText, 'seg': ansSeg, 'key': ansKey}

    def show(self):
        for i in range(len(self.segmentation['text'])):
            print(self.segmentation['key'][i])
            print(self.segmentation['text'][i])
            print(self.segmentation['seg'][i])
            print('\n')

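    # Expected input shape (a minimal sketch; the titles and text are made
    # up): a dict mapping section titles to lists of paragraphs.
    #
    #   article = {'Introduction': ['First paragraph.', 'Second paragraph.'],
    #              'Method': ['Another paragraph.'],
    #              'References': ['[1] ...']}          # skipped by insertSeg
    #   doc = Document(article, tokenizer)
    #   doc.show()   # prints key, paragraph and boundary type per block
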
    def updateReconstructionLoss(self, lossScore, index, model, device):
        '''Refresh lossScore after the paragraphs at index and index+1 have
        been merged; lossScore[i] always scores the pair (i, i+1).'''
        model.eval()
        lossScore.pop(index)
        paragraph = self.segmentation['text'][index]
        if index > 0:
            if self.segmentation['seg'][index - 1] == self.segStringHard:
                lossScore[index - 1] = np.inf
            else:
                before = self.segmentation['text'][index - 1]
                lossScore[index - 1] = paragraphLoss(before, paragraph, self.tokenizer, model, device)
        if index < len(self.segmentation['text']) - 1:
            if self.segmentation['seg'][index] == self.segStringHard:
                # was lossScore[index - 1], which clobbered the wrong pair
                lossScore[index] = np.inf
            else:
                follow = self.segmentation['text'][index + 1]
                lossScore[index] = paragraphLoss(paragraph, follow, self.tokenizer, model, device)
        return lossScore

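    # Invariant kept by merge() and updateReconstructionLoss(): with n text
    # blocks there are n-1 scores, lossScore[i] scores merging blocks i and
    # i+1, and np.inf marks a hard boundary that must never be merged.
    # E.g. merging pair 2 of blocks [A, B, C, D] pops lossScore[2] and
    # refreshes lossScore[1] (B against the merged CD block).
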
    def merge(self, minPage, maxPage, model, device):
        '''Greedily merge adjacent paragraphs until the block count falls to
        maxPage, then keep merging toward minPage while the cumulative score
        improves. Returns the final number of blocks.'''
        model.eval()
        if minPage > len(self.segmentation['text']):
            return len(self.segmentation['text'])
        # score every adjacent pair; hard boundaries are never merged
        lossScore = []
        for i in trange(len(self.segmentation['text']) - 1):
            paragraph1 = self.segmentation['text'][i]
            paragraph2 = self.segmentation['text'][i + 1]
            if self.segmentation['seg'][i] == self.segStringHard:
                loss = np.inf
            else:
                loss = paragraphLoss(paragraph1, paragraph2, self.tokenizer, model, device)
            lossScore.append(loss)
        # phase 1: merge the cheapest pair until at most maxPage blocks remain
        while len(self.segmentation['text']) > maxPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            index = lossScore.index(minScore)
            print('merging', index, 'and', index + 1)
            # update text
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index + 1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                                        [mergedParas] + \
                                        self.segmentation['text'][(index + 2):]
            # update key
            self.segmentation['key'].pop(index + 1)
            # update segments
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index - 1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    self.segmentation['seg'][index - 1] = self.segStringHard
            if index < len(self.segmentation['text']) - 1:
                follow = self.segmentation['text'][index + 1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    self.segmentation['seg'][index] = self.segStringHard
            # update loss
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
        # phase 2: keep merging down toward minPage, remembering the state
        # with the lowest cumulative score
        currentSegState = self.segmentation
        bestSegState = copy.deepcopy(currentSegState)
        currentSegScore = 0
        miniSegScore = 0
        while len(currentSegState['text']) > minPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            currentSegScore += minScore
            # update text
            index = lossScore.index(minScore)
            mergedParas = currentSegState['text'][index] + '\n' + currentSegState['text'][index + 1]
            currentSegState['text'] = currentSegState['text'][:index] + \
                                      [mergedParas] + \
                                      currentSegState['text'][(index + 2):]
            # update key and segments
            currentSegState['key'].pop(index + 1)
            currentSegState['seg'].pop(index)
            paragraph = currentSegState['text'][index]
            if index > 0:
                before = currentSegState['text'][index - 1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > 1024:
                    print('warning: merged pair exceeds 1024 tokens, marking hard boundary')
                    currentSegState['seg'][index - 1] = self.segStringHard
            if index < len(currentSegState['text']) - 1:
                follow = currentSegState['text'][index + 1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > 1024:
                    print('warning: merged pair exceeds 1024 tokens, marking hard boundary')
                    currentSegState['seg'][index] = self.segStringHard
            # update score
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
            if currentSegScore <= miniSegScore:
                print('merging', index, 'and', index + 1)
                miniSegScore = currentSegScore
                # snapshot a copy: the original assigned a bare reference, so
                # the "best" state was silently mutated by later merges
                bestSegState = copy.deepcopy(currentSegState)
        self.segmentation = bestSegState
        return len(self.segmentation['text'])
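

if __name__ == '__main__':
    # Minimal usage sketch. The checkpoint name and the article dict are
    # placeholders: the real pipeline supplies its own fine-tuned wrapper,
    # which this module only ever receives as `model`.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')  # assumed tokenizer
    article = {'Introduction': ['First paragraph.', 'Second paragraph.']}
    doc = Document(article, tokenizer)
    doc.show()
    # doc.merge(minPage=3, maxPage=8, model=model, device='cpu')  # needs the wrapper model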