# coding:utf-8
"""
Filename: mt5.py
Author: @DvdNss
Created on 12/30/2021
"""

from typing import List

from pytorch_lightning import LightningModule
from transformers import MT5ForConditionalGeneration, AutoTokenizer


class MT5(LightningModule):
    """
    Google MT5 transformer class.

    Wraps an MT5 conditional-generation model together with its tokenizer and exposes
    text-to-text helpers: question answering (qa), question generation (qg), answer
    extraction (ae) and a pipeline chaining the three (multitask).
    """

    def __init__(self, model_name_or_path: str = None):
        """
        Initialize module.
        :param model_name_or_path: model name or path passed to ``from_pretrained``; when None,
            neither model nor tokenizer is loaded and both attributes are set to None
        """

        super().__init__()

        # Record the init arguments before anything else is defined in this frame
        self.save_hyperparameters()

        # Load model and tokenizer
        if model_name_or_path is None:
            self.model = None
            self.tokenizer = None
        else:
            self.model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    def forward(self, **inputs):
        """
        Forward inputs.
        :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
        :return: output of the underlying model call
        """

        return self.model(**inputs)

    def qa(self, batch: List[dict], max_length: int = 512, **kwargs) -> List[str]:
        """
        Question answering prediction.
        :param batch: batch of dict {question: q, context: c}
        :param max_length: max length of output
        :param kwargs: extra keyword arguments forwarded to ``predict``
        :return: one decoded prediction per element of the batch
        """

        # Transform inputs
        inputs = [f"question: {sample['question']}  context: {sample['context']}" for sample in batch]

        # Predict
        return self.predict(inputs=inputs, max_length=max_length, **kwargs)

    def qg(self, batch: List[str] = None, max_length: int = 512, **kwargs) -> List[str]:
        """
        Question generation prediction.
        :param batch: batch of context with highlighted elements (None is treated as an empty batch)
        :param max_length: max length of output
        :param kwargs: extra keyword arguments forwarded to ``predict``
        :return: one decoded prediction per element of the batch
        """

        # The signature defaults to None, which cannot be iterated: fall back to an empty batch
        if batch is None:
            batch = []

        # Transform inputs
        inputs = [f"generate: {context}" for context in batch]

        # Predict
        return self.predict(inputs=inputs, max_length=max_length, **kwargs)

    def ae(self, batch: List[str], max_length: int = 512, **kwargs) -> List[str]:
        """
        Answer extraction prediction.
        :param batch: list of context
        :param max_length: max length of output
        :param kwargs: extra keyword arguments forwarded to ``predict``
        :return: one decoded prediction per element of the batch
        """

        # Transform inputs
        inputs = [f"extract: {context}" for context in batch]

        # Predict
        return self.predict(inputs=inputs, max_length=max_length, **kwargs)

    def multitask(self, batch: List[str], max_length: int = 512, **kwargs) -> dict:
        """
        Answer extraction + question generation + question answering.
        :param batch: list of context
        :param max_length: max length of outputs
        :param kwargs: extra keyword arguments forwarded to the underlying predictions
        :return: dict with keys 'context', 'answers', 'questions' and 'answers_bis'; each value is a
            list holding one entry per context
        """

        # Materialize once so that one-shot iterables are not exhausted before the loop below
        contexts = list(batch)

        # Build output dict
        dict_batch = {'context': contexts, 'answers': [], 'questions': [], 'answers_bis': []}

        # Iterate over context
        for context in contexts:
            # Extract answers; the prediction is a single string with answers separated by '<sep>'
            extracted = self.ae(batch=[context], max_length=max_length, **kwargs)[0]

            # Strip first, then drop empty fragments: an empty answer would make str.replace
            # insert the highlight markers between every character of the context
            answers = [fragment.strip() for fragment in extracted.split('<sep>')]
            answers = [ans for ans in answers if ans]
            dict_batch['answers'].append(answers)

            # Generate one question per answer, from the context with that answer highlighted
            for_qg = [context.replace(ans, f'<hl> {ans} <hl> ') for ans in answers]
            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
            dict_batch['questions'].append(questions)

            # Answer the generated questions against the original context
            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
                                  max_length=max_length, **kwargs)
            dict_batch['answers_bis'].append(new_answers)

        return dict_batch

    def predict(self, inputs, max_length, **kwargs) -> List[str]:
        """
        Inference processing.
        :param inputs: list of inputs
        :param max_length: max_length of outputs (also used as truncation length for the inputs)
        :param kwargs: extra keyword arguments forwarded to ``self.model.generate``
        :return: list of decoded predictions, special tokens skipped
        """

        # Nothing to predict: skip tokenization and generation for an empty batch
        if isinstance(inputs, (list, tuple)) and not inputs:
            return []

        # Tokenize inputs; pad only up to the longest sequence of the batch, the attention mask
        # already hides padding so padding everything to max_length only adds compute
        encoded = self.tokenizer(inputs, max_length=max_length, padding='longest', truncation=True,
                                 return_tensors="pt")

        # Retrieve input_ids and attention_mask on the model device
        device = self.model.device
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        # Predict
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
                                      **kwargs)

        # Decode outputs
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)