"""
    Get data and adapt it for training
    -----------
    - nettoyage de l'encodage
    - Ajout de token <START> et <END>
    TO DO :
    - Nettoyage des contractions
    - enlever les \xad
    - enlever ponctuation et () []
    - s'occuper des noms propres (mots commençant par une majuscule qui se suivent)
    Création d'un Vectoriser à partir du vocabulaire

"""
import pickle
import string
from collections import Counter

import pandas as pd
import torch


class Data(torch.utils.data.Dataset):
    """
    A dataset of (text, summary) pairs read from a JSON Lines file.

    Each line of the file must provide a ``text`` and a ``summary`` field.
    Entries are tokenised by repairing their encoding, removing ASCII
    punctuation and splitting on whitespace; summaries are additionally
    wrapped in ``<start>`` / ``<end>`` tokens. ``__getitem__`` and
    ``clean_data`` share the same tokenisation, so the words returned by
    ``get_words`` match the tokens the dataset yields.

    Attributes
    ----------
    path : str
        the path to the file containing the data
    data : pandas.DataFrame
        the file contents, loaded once at construction
    transform : callable or None
        optional function applied to every sample returned by ``__getitem__``

    Methods
    -------
    open()
        read the jsonl file with pandas
    clean_data(text_type)
        tokenise every entry of the ``text_type`` column, adding <start> and
        <end> tokens when text_type is "summary"
    get_words()
        list every token of the dataset (texts then summaries, duplicates kept)
    """

    # Translation table that deletes every ASCII punctuation character
    # (including () and []); built once instead of once per row.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, path: str, transform=None) -> None:
        self.path = path
        self.data = self.open()
        self.transform = transform

    def __len__(self):
        """Return the number of (text, summary) pairs."""
        return len(self.data)

    @staticmethod
    def _fix_encoding(text: str) -> str:
        """Repair an encoding round trip on a single string."""
        # Encoding as cp1252 then decoding as UTF-8 reverses UTF-8 text that
        # was mis-decoded as cp1252 (e.g. 'cafÃ©' -> 'café'). Errors are
        # ignored in both steps, so characters that cannot make the round
        # trip are silently dropped -- this includes correctly-encoded
        # non-ASCII characters such as 'é'.
        return text.encode("cp1252", "ignore").decode("utf-8", "ignore")

    @classmethod
    def _tokenize(cls, text: str) -> list:
        """Fix encoding, strip punctuation and split ``text`` on whitespace."""
        return cls._fix_encoding(text).translate(cls._PUNCTUATION_TABLE).split()

    def __getitem__(self, idx):
        """
        Return the tokenised pair at position ``idx``.

        Returns ``{"text": [...], "summary": ["<start>", ..., "<end>"]}``, or
        whatever ``self.transform`` returns when it is set.
        """
        row = self.data.iloc[idx]
        sample = {
            "text": self._tokenize(row["text"]),
            "summary": ["<start>", *self._tokenize(row["summary"]), "<end>"],
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

    def open(self) -> pd.DataFrame:
        """
        Read the JSON Lines file at ``self.path`` into a new DataFrame.
        """
        return pd.read_json(path_or_buf=self.path, lines=True)

    def clean_data(self, text_type: str) -> list:
        """
        Tokenise every entry of one column of the loaded data.

        Uses the DataFrame loaded at construction (``self.data``) rather than
        re-reading the file.

        TODO:
        - expand contractions
        - remove soft hyphens (\\xad)
        - handle proper nouns (consecutive capitalised words)

        Parameters
        ----------
        text_type : str
            the column to tokenise, 'text' or 'summary'; summaries are
            wrapped in <start> and <end> tokens

        Returns
        ----------
        list of list
            one list of tokens per row, in row order
        """
        tokenized_texts = [self._tokenize(text) for text in self.data[text_type]]

        if text_type == "summary":
            return [["<start>", *tokens, "<end>"] for tokens in tokenized_texts]
        return tokenized_texts

    def get_words(self) -> list:
        """
        Return every token of the dataset: all text tokens, then all summary
        tokens (including <start>/<end>).

        Duplicates are kept, so the result can be passed to a counter-based
        vocabulary builder such as ``Vectoriser`` (which filters rare words
        by occurrence count).
        """
        texts = self.clean_data("text")
        summaries = self.clean_data("summary")
        return [word for tokens in texts + summaries for word in tokens]


def pad_collate(data, pad_value=-100):
    """
    Collate a batch of ``(text_idx, summary_idx)`` pairs, right-padding every
    sequence to a common length.

    Texts and summaries are all padded to the length of the longest sequence
    in the batch (texts and summaries considered together).

    Parameters
    ----------
    data : sequence of pairs of tensors
        element[0] is the text tensor, element[1] the summary tensor; assumed
        1-D index tensors such as those produced by ``Vectoriser.encode``
        from a list of tokens
    pad_value : int, optional
        value written into padded positions (default -100, which matches the
        default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``)

    Returns
    -------
    tuple of (list of tensor, list of tensor)
        the padded text tensors and the padded summary tensors, in batch
        order; an empty batch yields ``([], [])``
    """
    text_batch = [element[0] for element in data]
    summary_batch = [element[1] for element in data]
    # default=0 keeps an empty batch from raising in max().
    max_len = max((len(element) for element in text_batch + summary_batch), default=0)
    # NOTE(review): texts are padded with the same value as summaries; a
    # negative pad value will raise if texts are later fed to an nn.Embedding.
    # Consider a dedicated non-negative pad index for texts.
    text_batch = [
        torch.nn.functional.pad(element, (0, max_len - len(element)), value=pad_value)
        for element in text_batch
    ]
    summary_batch = [
        torch.nn.functional.pad(element, (0, max_len - len(element)), value=pad_value)
        for element in summary_batch
    ]
    return text_batch, summary_batch


class Vectoriser:
    """
    Map tokens to integer indices and back.

    The index is built from a flat list of words (``vocab``). Every word is
    normalised (lower-cased, with leading/trailing ``,``, ``.``, ``\\`` and
    ``-`` stripped), and only words occurring at least ``min_count`` times
    are kept. Indices follow the sorted order of the normalised words. Any
    token missing from the index is encoded as ``len(token_to_idx)`` (the
    "unknown" index), which ``decode`` skips.

    Attributes
    ----------
    vocab : list or None
        the raw list of words the index is built from
    min_count : int
        minimum number of occurrences for a word to be indexed
    word_count : Counter
        occurrences of every normalised word
    idx_to_token : list of str
        the indexed words, sorted
    token_to_idx : dict
        inverse mapping of ``idx_to_token``

    Methods
    -------
    load(path) / save(path)
        read / write the raw vocab list with pickle
    encode(tokens)
        transforms a list of tokens (or a single token) to their
        corresponding idx in the form of a torch tensor
    decode(word_idx_tensor)
        converts a tensor of indices back to a list of tokens
    __call__(row)
        encode the "text" and "summary" entries of a row
    """

    # Characters stripped from both ends of a word before counting / lookup.
    _STRIP_CHARS = ",.\\-"

    def __init__(self, vocab=None, min_count: int = 2) -> None:
        self.vocab = vocab
        self.min_count = min_count
        self._build_index()

    @classmethod
    def _normalise(cls, word: str) -> str:
        """Normalise a word the same way for indexing and for lookup."""
        return word.lower().strip(cls._STRIP_CHARS)

    def _build_index(self) -> None:
        """(Re)build word_count, idx_to_token and token_to_idx from vocab."""
        # vocab=None yields an empty index instead of failing, so an instance
        # can be created first and filled with load().
        words = self.vocab if self.vocab is not None else ()
        self.word_count = Counter(self._normalise(word) for word in words)
        self.idx_to_token = sorted(
            token for token, count in self.word_count.items()
            if count >= self.min_count
        )
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

    def load(self, path):
        """Replace the vocab with the pickled list at ``path`` and re-index."""
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # vocab files from trusted sources.
        with open(path, "rb") as file:
            self.vocab = pickle.load(file)
        self._build_index()

    def save(self, path):
        """Pickle the raw vocab list to ``path``."""
        with open(path, "wb") as file:
            pickle.dump(self.vocab, file)

    def encode(self, tokens) -> torch.Tensor:
        """
        Encode tokens as indices from the index.

        Tokens are normalised like the vocabulary before lookup. A token
        missing from the index gets the fixed "unknown" index
        ``len(self.token_to_idx)``, which ``decode`` ignores.

        :params: tokens : list, tuple or str
            the words of a sentence, or a single word
        :return: words_idx : tensor
            a 1-D long tensor of indices for a list/tuple, or a 0-D long
            tensor for a single word
        :raises: TypeError if ``tokens`` is not a list, tuple or str
        """
        unknown = len(self.token_to_idx)

        # Single-word encoding
        if isinstance(tokens, str):
            return torch.tensor(
                self.token_to_idx.get(self._normalise(tokens), unknown),
                dtype=torch.long,
            )

        if isinstance(tokens, (list, tuple)):
            return torch.tensor(
                [self.token_to_idx.get(self._normalise(t), unknown) for t in tokens],
                dtype=torch.long,
            )

        raise TypeError(
            f"tokens must be a list, tuple or str, not {type(tokens).__name__}"
        )

    def decode(self, words_idx_tensor) -> list:
        """
        Decode a tensor of indices back into a list of words.

        Indices outside ``[0, len(idx_to_token))`` -- the "unknown" index
        produced by ``encode`` and negative padding values such as -100 --
        are skipped rather than mapped to a word.
        """
        idxs = words_idx_tensor.tolist()
        if isinstance(idxs, int):
            idxs = [idxs]
        size = len(self.idx_to_token)
        return [self.idx_to_token[idx] for idx in idxs if 0 <= idx < size]

    def __call__(self, row) -> tuple:
        """
        Encode the data of one dataset row.

        :params: row : mapping
            a row (a text-summary pair) exposing "text" and "summary"
            token lists
        :returns: (text_idx, summary_idx) : tuple of tensors
            the encoded text and the encoded summary
        """
        text_idx = self.encode(row["text"])
        summary_idx = self.encode(row["summary"])
        return (text_idx, summary_idx)