import copy
import re

import numpy as np
import torch
from tqdm import trange

# Maximum token length for two adjacent paragraphs to remain mergeable;
# longer pairs are separated by a hard segment boundary.
MAX_BLOCK_TOKENS = 1024


def select_sentences(paragraph, num_sentences):
    """Select sentences from a paragraph.

    A negative num_sentences keeps the last |num_sentences| sentences,
    a positive value keeps the first num_sentences, and zero keeps none.
    """
    sentences = re.split(r'(?<=[.!?])\s+', paragraph)
    if num_sentences < 0:
        selected_sentences = sentences[num_sentences:]
    elif num_sentences > 0:
        selected_sentences = sentences[:num_sentences]
    else:
        selected_sentences = []
    return ' '.join(selected_sentences)
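# Example: keeping the trailing or leading sentences of a made-up paragraph.
#
#     select_sentences('One. Two. Three!', -2)   # -> 'Two. Three!'
#     select_sentences('One. Two. Three!', 2)    # -> 'One. Two.'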

def getitem(dataset, index):
    """Wrap the index-th tokenized example as batch-size-1 tensors."""
    inputs = dict()
    inputs['input_ids'] = torch.LongTensor([dataset['input_ids'][index]])
    inputs['attention_mask'] = torch.LongTensor([dataset['attention_mask'][index]])
    return inputs
    
def reconstructionLoss(blocks, tokenizer, model, device):
    """Return a list with the summed token-level cross-entropy reconstruction
    loss of each text block."""
    scores = []
    model.eval()
    inputDataset = tokenizer(blocks)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for i in range(len(blocks)):
            inputs = getitem(inputDataset, i)
            # The block serves as both model input and reconstruction target.
            dl_input = dict()
            dl_input['summ_input_ids'] = inputs['input_ids'].to(device)
            dl_input['summ_attention_mask'] = inputs['attention_mask'].to(device)
            dl_input['exp_decoder_ids'] = inputs['input_ids'].to(device)
            dl_input['exp_attention_mask'] = inputs['attention_mask'].to(device)

            labels = torch.flatten(inputs['input_ids']).to(device)
            outputs = model(dl_input)
            score = loss_fn(outputs.squeeze(), labels)
            scores.append(score.item())
    return scores

def paragraphLoss(paragraph1, paragraph2, tokenizer, model, device):
    """Merge score for two adjacent paragraphs: the reconstruction loss of the
    paragraphs encoded separately minus the loss of their concatenation."""
    model.eval()
    splitScore1 = reconstructionLoss([paragraph1], tokenizer, model, device)[0]
    splitScore2 = reconstructionLoss([paragraph2], tokenizer, model, device)[0]
    splitScore = splitScore1 + splitScore2
    mergedParas = paragraph1 + '\n' + paragraph2
    mergedScore = reconstructionLoss([mergedParas], tokenizer, model, device)[0]
    return splitScore - mergedScore
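# A usage sketch (hypothetical paragraphs; `tokenizer` and `model` must follow
# the summ_*/exp_* interface assumed by reconstructionLoss above):
#
#     score = paragraphLoss('A paragraph on topic A.', 'A paragraph on topic B.',
#                           tokenizer, model, device='cuda')
#     # Lower scores mark adjacent paragraphs that are better merge candidates.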

class Document:
    def __init__(self, text, tokenizer,
                 segsoft='<block seg soft>', seghard='<block seg hard>'):
        '''
        text: dict mapping section keys to lists of paragraph strings
        tokenizer: tokenizer used to measure block lengths in tokens
        segsoft, seghard: marker strings for soft and hard block boundaries
        '''
        self.text = text
        self.tokenizer = tokenizer
        self.getSegString(segsoft, seghard)
        self.segmentation = self.insertSeg(text)

    def gettext(self):
        return self.text
    
    def getSegString(self, segsoft, seghard):
        # The boundary markers must not already occur in the document text.
        allText = ' '.join(p for content in self.text.values() for p in content)
        if (segsoft not in allText) and (seghard not in allText):
            self.segStringSoft = segsoft
            self.segStringHard = seghard
        else:
            raise ValueError('Segment string invalid, provide unique segment strings!')

    def insertSeg(self, article):
        """Mark each paragraph with a boundary: soft if it fits into one block
        together with the following paragraph, hard otherwise or at the end of
        a section. Sections named 'References'/'Reference' are dropped."""
        ansText = []
        ansSeg = []
        ansKey = []
        tokenizer = self.tokenizer
        for key, content in article.items():
            if key in ['References', 'Reference']:
                continue
            for i in range(len(content)):
                paragraph = content[i]
                if i == len(content) - 1:
                    seg = self.segStringHard
                    ansText.append(paragraph)
                    ansSeg.append(seg)
                    ansKey.append(key)
                    break
                
                follow = content[i+1]
                twoPara = paragraph + ' ' + follow
                if len(tokenizer(twoPara)['input_ids']) < MAX_BLOCK_TOKENS:
                    seg = self.segStringSoft
                else:
                    seg = self.segStringHard
                ansText.append(paragraph)
                ansSeg.append(seg)
                ansKey.append(key)
        ans = {'text': ansText, 'seg': ansSeg, 'key':ansKey}
        return ans
    
    def show(self):
        for i in range(len(self.segmentation['text'])):
            print(self.segmentation['key'][i])
            print(self.segmentation['text'][i])
            print(self.segmentation['seg'][i])
            print('\n')
            
    def updateReconstructionLoss(self, lossScore, index, model, device):
        """After merging blocks index and index+1, refresh the merge scores
        of the merged block against its new neighbours."""
        model.eval()
        lossScore.pop(index)
        paragraph = self.segmentation['text'][index]
        if index > 0:
            if self.segmentation['seg'][index-1] == self.segStringHard:
                lossScore[index-1] = np.inf
            else:
                before = self.segmentation['text'][index-1]
                lossScore[index-1] = paragraphLoss(before, paragraph, self.tokenizer, model, device)
        if index < len(self.segmentation['text'])-1:
            if self.segmentation['seg'][index] == self.segStringHard:
                lossScore[index] = np.inf
            else:
                follow = self.segmentation['text'][index+1]
                lossScore[index] = paragraphLoss(paragraph, follow, self.tokenizer, model, device)

        return lossScore
        
    def merge(self, minPage, maxPage, model, device):
        """Greedily merge adjacent soft-separated blocks: first down to maxPage
        blocks, then keep merging toward minPage while tracking the state with
        the lowest cumulative merge score."""
        model.eval()
        if minPage > len(self.segmentation['text']):
            return len(self.segmentation['text'])
        
        lossScore = []
        for i in trange(len(self.segmentation['text']) - 1):
            paragraph1 = self.segmentation['text'][i]
            paragraph2 = self.segmentation['text'][i+1]
            if self.segmentation['seg'][i] == self.segStringHard:
                loss = np.inf
            else:
                loss = paragraphLoss(paragraph1, paragraph2, self.tokenizer, model, device)
            lossScore.append(loss)
        
        while len(self.segmentation['text']) > maxPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            index = lossScore.index(minScore)
            print('merging', index, 'and', index+1)
            # update text
            mergedParas = self.segmentation['text'][index] + '\n' + self.segmentation['text'][index+1]
            self.segmentation['text'] = self.segmentation['text'][:index] + \
                                        [mergedParas] + \
                                        self.segmentation['text'][(index+2):]
            # update key
            self.segmentation['key'].pop(index+1)
            
            # update segments
            self.segmentation['seg'].pop(index)
            paragraph = self.segmentation['text'][index]
            if index > 0:
                before = self.segmentation['text'][index-1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > MAX_BLOCK_TOKENS:
                    self.segmentation['seg'][index-1] = self.segStringHard
            
            if index < len(self.segmentation['text'])-1:
                follow = self.segmentation['text'][index+1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > MAX_BLOCK_TOKENS:
                    self.segmentation['seg'][index] = self.segStringHard

            # update loss
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
            
        # Continue merging in place toward minPage, remembering the best
        # (lowest cumulative merge score) state seen so far.
        currentSegState = self.segmentation
        bestSegState = copy.deepcopy(self.segmentation)
        currentSegScore = 0
        miniSegScore = 0
        while len(currentSegState['text']) > minPage and min(lossScore) < np.inf:
            minScore = min(lossScore)
            currentSegScore += minScore
            # update text
            index = lossScore.index(minScore)
            mergedParas = currentSegState['text'][index] + '\n' + currentSegState['text'][index+1]
            currentSegState['text'] = currentSegState['text'][:index] + \
                                      [mergedParas] + \
                                      currentSegState['text'][(index+2):]
            # update key
            currentSegState['key'].pop(index+1)
            currentSegState['seg'].pop(index)
            paragraph = currentSegState['text'][index]
            if index > 0:
                before = currentSegState['text'][index-1]
                twoPara1 = before + '\n' + paragraph
                if len(self.tokenizer(twoPara1)['input_ids']) > MAX_BLOCK_TOKENS:
                    print('warning: merged block and its predecessor exceed the block limit')
                    currentSegState['seg'][index-1] = self.segStringHard
            if index < len(currentSegState['text'])-1:
                follow = currentSegState['text'][index+1]
                twoPara2 = paragraph + '\n' + follow
                if len(self.tokenizer(twoPara2)['input_ids']) > MAX_BLOCK_TOKENS:
                    print('warning: merged block and its successor exceed the block limit')
                    currentSegState['seg'][index] = self.segStringHard
            # update loss scores around the merged block
            lossScore = self.updateReconstructionLoss(lossScore, index, model, device)
            if currentSegScore <= miniSegScore:
                print('merging', index, 'and', index+1)
                miniSegScore = currentSegScore
                bestSegState = copy.deepcopy(currentSegState)
        # Adopt the best state found during the second merging phase.
        self.segmentation = bestSegState
        return len(self.segmentation['text'])
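
# A minimal end-to-end sketch. The article dict is made up, and `tokenizer`
# and `model` stand for a tokenizer and a reconstruction model compatible
# with reconstructionLoss; neither is defined in this module.
#
#     article = {'Introduction': ['First paragraph.', 'Second paragraph.'],
#                'Methods': ['Third paragraph.']}
#     doc = Document(article, tokenizer)
#     doc.show()                                   # inspect initial boundaries
#     pages = doc.merge(minPage=1, maxPage=2, model=model, device='cuda')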