"""
    Defines the Encoder, Decoder and Sequence to Sequence models
    used in this projet
"""
import logging

import torch

import dataloader

logging.basicConfig(level=logging.DEBUG)

data1 = dataloader.Data("data/train_extract.jsonl")
words = data1.get_words()
vectoriser = dataloader.Vectoriser(words)


class Encoder(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embeddings_dim: int,
        hidden_size: int,
        dropout: int,
        device,
    ):
        # Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
        super().__init__()
        self.device = device
        # On ajoute un mot supplémentaire au vocabulaire :
        # on s'en servira pour les mots inconnus
        self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
        self.embeddings.to(device)
        self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
        # Comme on va calculer la log-vraisemblance,
        # c'est le log-softmax qui nous intéresse
        self.dropout = torch.nn.Dropout(dropout)
        self.dropout.to(self.device)
        # Dropout

    def forward(self, inpt):
        inpt.to(self.device)
        emb = self.dropout(self.embeddings(inpt)).to(self.device)
        emb = emb.to(self.device)

        output, (hidden, cell) = self.hidden(emb)
        output.to(self.device)
        hidden = hidden.to(self.device)
        cell = cell.to(self.device)

        return hidden, cell


class Decoder(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embeddings_dim: int,
        hidden_size: int,
        dropout: int,
        device,
    ):
        # Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
        super().__init__()
        self.device = device
        # On ajoute un mot supplémentaire au vocabulaire :
        # on s'en servira pour les mots inconnus
        self.vocab_size = vocab_size
        self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
        self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
        self.output = torch.nn.Linear(hidden_size, vocab_size)
        # Comme on va calculer la log-vraisemblance,
        # c'est le log-softmax qui nous intéresse
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)
        input = input.to(self.device)
        emb = self.dropout(self.embeddings(input)).to(self.device)
        emb = emb.to(self.device)

        output, (hidden, cell) = self.hidden(emb, (hidden, cell))
        output = output.to(self.device)
        out = self.output(output.squeeze(0)).to(self.device)
        return out, hidden, cell


class EncoderDecoderModel(torch.nn.Module):
    def __init__(self, encoder, decoder, device):
        # Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, num_beams=3):
        # CHANGER LA TARGET LEN POUR QQCH DE MODULABLE
        target_len = int(1 * source.shape[0])  # Taille du texte que l'on recherche
        target_vocab_size = self.decoder.vocab_size  # Taille du mot

        # tensor to store decoder outputs
        outputs = torch.zeros(target_len, target_vocab_size).to(
            self.device
        )  # Instenciation d'une matrice de zeros de taille (taille du texte, taille du mot)
        outputs.to(
            self.device
        )  # Une idiosyncrasie de torch pour mettre le tensor sur le GPU

        # last hidden state of the encoder is used as the initial hidden state of the decoder
        source.to(
            self.device
        )  # Une idiosyncrasie de torch pour mettre le tensor sur le GPU
        hidden, cell = self.encoder(source)  # Encode le texte sous forme de vecteur
        hidden.to(
            self.device
        )  # Une idiosyncrasie de torch pour mettre le tensor sur le GPU
        cell.to(
            self.device
        )  # Une idiosyncrasie de torch pour mettre le tensor sur le GPU

        # first input to the decoder is the <start> token.
        input = vectoriser.encode("<start>")  # Mot de départ du MOdèle
        input.to(self.device)  # idiosyncrasie de torch pour mmettre sur GPU

        ### DÉBUT DE L'INSTANCIATION TEST ###
        # If you wonder, b stands for better
        values = None
        b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
        b_outputs.to(self.device)

        for i in range(
            1, target_len
        ):  # On va déterminer autant de mot que la taille du texte souhaité
            # insert input token embedding, previous hidden and previous cell states
            # receive output tensor (predictions) and new hidden and cell states.

            # replace predictions in a tensor holding predictions for each token
            # logging.debug(f"output : {output}")

            ####### DÉBUT DU BEAM SEARCH ##########
            if values is None:
                # On calcule une première fois les premières probabilité de mot après <start>
                output, hidden, cell = self.decoder(input, hidden, cell)
                output.to(self.device)
                b_hidden = hidden
                b_cell = cell

                # On choisi les k meilleurs scores pour choisir la meilleure probabilité
                # sur deux itérations ensuite
                values, indices = output.topk(num_beams, sorted=True)

            else:
                # On instancie le dictionnaire qui contiendra les scores pour chaque possibilité
                scores = {}

                # Pour chacune des meilleures valeurs, on va calculer l'output
                for value, indice in zip(values, indices):
                    indice.to(self.device)

                    # On calcule l'output
                    b_output, b_hidden, b_cell = self.decoder(indice, b_hidden, b_cell)

                    # On empêche le modèle de se répéter d'un mot sur l'autre en mettant
                    # de force la probabilité du mot précédent à 0
                    b_output[indice] = torch.zeros(1)

                    # On choisit le meilleur résultat pour cette possibilité
                    highest_value = torch.log(b_output).max()

                    # On calcule le score des 2 itérations ensembles
                    score = highest_value * torch.log(value)
                    scores[score] = (b_output, b_hidden, b_cell)

                # On garde le meilleur score sur LES 2 ITÉRATIONS
                b_output, b_hidden, b_cell = scores.get(max(scores))

                # Et du coup on rempli la place de i-1 à la place de i
                b_outputs[i - 1] = b_output.to(self.device)

                # On instancies nos nouvelles valeurs pour la prochaine itération
                values, indices = b_output.topk(num_beams, sorted=True)

            ##################################

            # outputs[i] = output.to(self.device)
            # input = output.argmax(dim=-1).to(self.device)
            # input.to(self.device)

        # logging.debug(f"{vectoriser.decode(outputs.argmax(dim=-1))}")
        return b_outputs.to(self.device)