"""
    Get data and adapt it for training
    -----------
    - nettoyage de l'encodage
    - Ajout de token <START> et <END>
    TO DO :
    - Nettoyage des contractions
    - enlever les \xad
    - enlever ponctuation et () []
    - s'occuper des noms propres (mots commençant par une majuscule qui se suivent)
    Création d'un Vectoriserà partir du vocabulaire :

"""
import string
from collections import Counter

import pandas as pd
import torch
from nltk import word_tokenize

# nltk.download('punkt')  # uncomment on first run to fetch the tokeniser model


class Data:
    """
    A class used to get data from a file
    ...

    Attributes
    ----------
    path : str
        the path to the file containing the data

    Methods
    -------
    open()
        open the jsonl file with pandas
    clean_data(text_type)
        clean the data obtained by opening the file and add <start> and
        <end> tokens depending on the text_type
    pad_sequence()
        pad each summary with <empty> tokens to match the length of its text
    get_words()
        get the dataset vocabulary
    make_dataset()
        create a dataset with the cleaned data
    """

    def __init__(self, path: str) -> None:
        self.path = path

    def open(self) -> pd.DataFrame:
        """
        Open the file containing the data
        """
        return pd.read_json(path_or_buf=self.path, lines=True)

    def clean_data(self, text_type: str) -> list:
        """
        Clean data from encoding errors, punctuation, etc.
        To do:
        - finish cleaning the data (see the module-level TO DO list)

        Parameters
        ----------
        text_type : str
            allows differentiating between 'text' and 'summary'
            to add <start> and <end> tokens to summaries

        Returns
        ----------
        list of list
            list of tokenised texts

        """
        dataset = self.open()

        texts = dataset[text_type]
        texts = texts.str.encode("cp1252", "ignore")
        texts = texts.str.decode("utf-8", "ignore")

        tokenized_texts = []
        # TO DO:
        # - expand contractions
        # - remove the \xad characters
        # - remove () and []  (plain punctuation is handled by the translate below)
        # - handle proper nouns (consecutive words starting with a capital letter)
        for text in texts:
            # Strip punctuation, then tokenise with NLTK
            text = text.translate(str.maketrans("", "", string.punctuation))
            text = word_tokenize(text)
            tokenized_texts.append(text)

        if text_type == "summary":
            return [["<start>", *summary, "<end>"] for summary in tokenized_texts]
        return tokenized_texts

    def pad_sequence(self):
        """
        Pad each summary with <empty> tokens so it matches the length of its text
        """
        texts = self.clean_data("text")
        summaries = self.clean_data("summary")
        padded_summary = []
        for text, summary in zip(texts, summaries):
            if len(summary) != len(text):
                summary += ["<empty>"] * (len(text) - len(summary))
            padded_summary.append(summary)
        return texts, padded_summary

    def get_words(self) -> list:
        """
        Create a list of all the tokens in the dataset vocabulary
        """
        texts, summaries = self.clean_data("text"), self.clean_data("summary")
        text_words = [word for text in texts for word in text]
        summary_words = [word for summary in summaries for word in summary]
        return text_words + summary_words

    def make_dataset(self) -> pd.DataFrame:
        """
        Create a Pandas Dataframe with cleaned data
        --------------------
        param: self: Data
        return: pd.DataFrame
        """
        texts, summaries = self.clean_data("text"), self.clean_data("summary")
        return pd.DataFrame(list(zip(texts, summaries)), columns=["text", "summary"])
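
# A brief usage sketch for Data (not from the original repo); "articles.jsonl"
# is a hypothetical path to a JSON-lines file with 'text' and 'summary' fields:
#
#   data = Data("articles.jsonl")
#   df = data.make_dataset()   # DataFrame of tokenised text/summary pairs
#   vocab = data.get_words()   # flat list of every token, for the Vectoriser below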


class Vectoriser:
    """
    A class used to vectorise data
    ...

    Attributes
    ----------
    vocab : list
        list of all the tokens in the corpus vocabulary

    Methods
    -------
    encode(tokens)
        transform a list of tokens into their corresponding indices
        as a torch tensor
    decode(words_idx_tensor)
        convert a tensor back into a list of tokens
    vectorize(row)
        encode an entire row from the dataset
    """

    def __init__(self, vocab) -> None:
        self.vocab = vocab
        # Count lowercased tokens, stripped of surrounding punctuation
        self.word_count = Counter(word.lower().strip(",.\\-") for word in self.vocab)
        # Keep only tokens seen more than once; their position in this sorted
        # list is their index, and len(idx_to_token) is reserved for
        # out-of-vocabulary words
        self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

    def encode(self, tokens) -> torch.tensor:
        """
        Encode une phrase selon les mots qu'elle contient
        selon les mots contenus dans le dictionnaire.
        À NOTER :
        Si un mot n'est pas contenu dans le dictionnaire,
        associe un index fixe au mot qui sera ignoré au décodage.
        ---------
        :params: tokens : list
            les mots de la phrase sous forme de liste
        :return: words_idx : tensor
            Un tensor contenant les index des mots de la phrase
        """
        if type(tokens) == list:
            words_idx = torch.tensor(
                [
                    self.token_to_idx.get(t.lower(), len(self.token_to_idx))
                    for t in tokens
                ],
                dtype=torch.long,
            )

        # Permet d'encoder mots par mots
        elif type(tokens) == str:
            words_idx = torch.tensor(self.token_to_idx.get(tokens.lower()))

        return words_idx

    def decode(self, words_idx_tensor) -> list:
        """
        Decode une phrase selon le procédé inverse que la fonction encode
        """
        words_idx_tensor = words_idx_tensor.argmax(dim=-1)
        idxs = words_idx_tensor.tolist()
        if type(idxs) == int:
            words = [self.idx_to_token[idxs]]
        else:
            words = []
            for idx in idxs:
                if idx != len(self.idx_to_token):
                    words.append(self.idx_to_token[idx])
        return words

    def beam_search(self, words_idx_tensor) -> list:
        pass

    def vectorize(self, row) -> torch.tensor:
        """
        Encode les données d'une ligne du dataframe
        ----------
        :params: row : dataframe
            une ligne du dataframe (un coupe texte-résumé)
        :returns: text_idx : tensor
            le tensor correspondant aux mots du textes
        :returns: summary_idx: tensor
            le tensr correspondant aux mots du résumé
        """
        text_idx = self.encode(row["text"])
        summary_idx = self.encode(row["summary"])
        return (text_idx, summary_idx)
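

# Minimal end-to-end sketch, not part of the original pipeline: it assumes a
# hypothetical JSON-lines file "data.jsonl" with 'text' and 'summary' fields
# and simply runs the Data -> Vectoriser steps on the first row.
if __name__ == "__main__":
    data = Data("data.jsonl")  # hypothetical path to the corpus
    dataset = data.make_dataset()
    vectoriser = Vectoriser(data.get_words())

    first_row = dataset.iloc[0]
    text_idx, summary_idx = vectoriser.vectorize(first_row)
    print(text_idx.shape, summary_idx.shape)
    # Note: decode() expects a distribution over the vocabulary (it applies
    # argmax), so the index tensors produced by encode() are not passed to it.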