"""
Training the network
"""
import datetime
import logging
import random
import time
from typing import Sequence, Tuple

import torch

import dataloader
from model import Decoder, Encoder, EncoderDecoderModel

# Logging levels, from lowest to highest severity:
# DEBUG, INFO, WARNING, ERROR, CRITICAL
logging.basicConfig(level=logging.INFO)
# Silence DEBUG messages (logging.disable suppresses everything at or
# below the given level, and 10 == logging.DEBUG)
logging.disable(level=logging.DEBUG)


def train_network(
    model: torch.nn.Module,
    train_set: Sequence[Tuple[torch.Tensor, torch.Tensor]],
    dev_set: Sequence[Tuple[torch.Tensor, torch.Tensor]],
    epochs: int,
    clip: float = 1.0,
):
    """
    Train the EncoderDecoderModel network for a given number of epoch
    -----------
    Parameters
        model: torch.nn.Module
            EncoderDecoderModel defined in model.py
        train_set: Sequence[Tuple[torch.tensor, torch.tensor]]
            tuple of vectorized (text, summary) from the training set
        dev_set: Sequence[Tuple[torch.tensor, torch.tensor]]
            tuple of vectorized (text, summary) for the dev set
        epochs: int
            the number of epochs to train on
        clip: int
            no idea
    Return
        None
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # nn.Module has no .device attribute: read it from the parameters
    print("Device check. You are using:", next(model.parameters()).device)

    optim = torch.optim.Adam(model.parameters(), lr=0.01)
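    # Note: lr=0.01 is ten times torch.optim.Adam's default of 1e-3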

    print("Epoch\ttrain loss\tdev accuracy\tcompute time")

    for epoch_n in range(epochs):
        # Tell the model it's in train mode for layers designed to
        # behave differently in train or evaluation
        # https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch
        model.train()

        # To get the computing time per epoch
        epoch_start_time = time.time()

        # To get the model accuracy per epoch
        epoch_loss = 0.0
        epoch_length = 0

        # Iterates over all the text, summary tuples
        for source, target in train_set:
            source = source.to(device)
            target = target.to(device)

            # DEBUG block
            # logging.debug("TRAIN")
            # logging.debug(f"cuda available? {torch.cuda.is_available()}")
            # logging.debug(f"Source on cuda? {source.is_cuda}")
            # logging.debug(f"Target on cuda? {target.is_cuda}")

            # model and source are already on device, so out is too
            out = model(source)
            logging.debug(f"outputs = {out.shape}")

            # To compare the output with the target they must have the
            # same length, so the target is padded with the -100 index,
            # which nll_loss ignores by default (ignore_index=-100)
            target = torch.nn.functional.pad(
                target, (0, len(out) - len(target)), value=-100
            )

            loss = torch.nn.functional.nll_loss(out, target)

            # Reset the gradients before backprop; without zero_grad()
            # they would accumulate across batches
            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optim.step()

            epoch_loss += loss.item()
            epoch_length += source.shape[0]

        # Track the model loss on held-out data
        dev_loss_total = 0.0
        dev_length = 0

        # We want to evaluate the model here, so switch to evaluation
        # mode once for the whole dev pass and disable gradient
        # tracking: only the forward pass is needed
        model.eval()
        with torch.no_grad():
            # Iterates over the (text, summary) tuples from dev
            for source, target in dev_set:
                source = source.to(device)
                target = target.to(device)

                # Compute the model output
                output = model(source)

                output_dim = output.shape[-1]

                output = output[1:].view(-1, output_dim)
                logging.debug(f"dev output : {output.shape}")
                target = target[1:].view(-1)
                # To compare the output with the target they must have
                # the same length, so the target is padded with the
                # -100 index, which nll_loss ignores by default
                target = torch.nn.functional.pad(
                    target, (0, len(output) - len(target)), value=-100
                )
                dev_loss = torch.nn.functional.nll_loss(output, target)
                dev_loss_total += dev_loss.item()
                dev_length += source.shape[0]

        # Compute of the epoch training time
        epoch_compute_time = time.time() - epoch_start_time

        print(
            f"{epoch_n}\t{epoch_loss / epoch_length:.5}"
            f"\t{dev_loss_total / dev_length:.5}"
            f"\t\t{datetime.timedelta(seconds=epoch_compute_time)}"
        )
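

# A minimal sketch of a token-level accuracy metric, in case a true
# "dev accuracy" (rather than the dev loss reported above) is wanted.
# It assumes, as the evaluation loop above does, that the model output
# is a [seq_len, vocab_size] tensor of log-probabilities:
def dev_accuracy(
    model: torch.nn.Module,
    dev_set: Sequence[Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
) -> float:
    """Share of target tokens predicted exactly right on the dev set"""
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for source, target in dev_set:
            source, target = source.to(device), target.to(device)
            # Most probable token at each position, skipping the first
            # (start) position as the evaluation loop above does
            predictions = model(source)[1:].argmax(dim=-1).view(-1)
            gold = target[1:].view(-1)
            # Compare only the overlapping positions
            length = min(len(predictions), len(gold))
            correct += (predictions[:length] == gold[:length]).sum().item()
            total += length
    return correct / max(total, 1)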


def predict(model, tokens: Sequence[str]) -> Sequence[str]:
    """Predict the summary for a tokenized sequence

    Relies on the module-level `vectoriser` and `device` defined in
    the __main__ block below.
    """
    words_idx = vectoriser.encode(tokens).to(device)
    # No gradient computation here: this is inference only
    with torch.no_grad():
        # equivalent to model.forward(input) when called outside the class
        out = model(words_idx)
    # Keep the most probable token at each position
    out_predictions = out.argmax(dim=-1)
    return vectoriser.decode(out_predictions)
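

# Example usage, mirroring the __main__ block below:
#     tokens = vectoriser.decode(dev_dataset[6][0])
#     summary = predict(trained_classifier, tokens)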


if __name__ == "__main__":
    train_dataset = dataloader.Data("data/train_extract.jsonl")
    words = train_dataset.get_words()
    vectoriser = dataloader.Vectoriser(words)

    train_dataset = dataloader.Data(
        "data/train_extract.jsonl",
        transform=vectoriser)
    dev_dataset = dataloader.Data(
        "data/dev_extract.jsonl",
        transform=vectoriser)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=dataloader.pad_collate)
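    # dataloader.pad_collate presumably pads the (text, summary) pairs
    # within a batch to a common length so they can be stacked into
    # tensors (an assumption based on its name; see dataloader.py)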

    dev_dataloader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=dataloader.pad_collate)

    # Sanity check: inspect the first batch produced by the dataloader
    for i_batch, batch in enumerate(train_dataloader):
        print(i_batch, batch[0], batch[1])
        break

    ### NEURAL NETWORK ###
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device check. You are using:", device)

    ### TRAINED NETWORK ###
    # To make sure the results are the same on every run of the
    # notebook
    torch.use_deterministic_algorithms(True)
    torch.manual_seed(0)
    random.seed(0)
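    # Note: with use_deterministic_algorithms(True), PyTorch raises an
    # error on nondeterministic CUDA ops, and deterministic cuBLAS may
    # additionally require the CUBLAS_WORKSPACE_CONFIG environment
    # variable (e.g. ":4096:8"); see the PyTorch reproducibility notes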

    # The encoder can also be trained separately
    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
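    # (The Encoder/Decoder positional arguments are presumably the
    # vocabulary size, embedding dim, hidden dim, dropout, and device;
    # see model.py)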

    trained_classifier = EncoderDecoderModel(
        encoder, decoder, vectoriser, device).to(device)

    print(next(trained_classifier.parameters()).device)

    # Note: training iterates over the unbatched datasets directly; the
    # batched dataloaders above are only used for the sanity check
    train_network(
        trained_classifier,
        train_dataset,
        dev_dataset,
        epochs=2,
    )

    torch.save(trained_classifier.state_dict(), "model/model.pt")
    vectoriser.save("model/vocab.pkl")
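
    # To reload the trained model later (a sketch; assumes the Encoder,
    # Decoder and Vectoriser are rebuilt exactly as above):
    #     model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
    #     model.load_state_dict(torch.load("model/model.pt"))
    #     model.eval()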

    print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
    print(
        f"test prediction : {predict(trained_classifier, vectoriser.decode(dev_dataset[6][0]))}"
    )