from typing import List

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PerceiverTokenizer


def _map_outputs(predictions):
    """
    Map model outputs to classes.

    :param predictions: model output batch
    :return: list of predicted label lists, one per example
    """

    labels = [
        "admiration",
        "amusement",
        "anger",
        "annoyance",
        "approval",
        "caring",
        "confusion",
        "curiosity",
        "desire",
        "disappointment",
        "disapproval",
        "disgust",
        "embarrassment",
        "excitement",
        "fear",
        "gratitude",
        "grief",
        "joy",
        "love",
        "nervousness",
        "optimism",
        "pride",
        "realization",
        "relief",
        "remorse",
        "sadness",
        "surprise",
        "neutral"
    ]
    classes = []
    for example in predictions:
        # Keep every label whose score exceeds the 0.5 threshold
        out_batch = [labels[j] for j, category in enumerate(example) if category > 0.5]
        classes.append(out_batch)
    return classes
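
# Example (illustrative scores, not real model output): _map_outputs keeps
# every label whose score exceeds 0.5.
#
#   >>> scores = [[0.0] * 28]
#   >>> scores[0][17], scores[0][18] = 0.9, 0.7   # 'joy' and 'love'
#   >>> _map_outputs(scores)
#   [['joy', 'love']]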


class MultiLabelPipeline:
    """
    Multi-label classification pipeline.
    """

    def __init__(self, model_path):
        """
        Init MLC pipeline.

        :param model_path: path to a model serialized with torch.save
        """

        # Init attributes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # map_location loads the weights directly onto the selected device,
        # so no separate CPU/GPU branches are needed
        self.model = torch.load(model_path, map_location=self.device).eval()
        self.tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')

    def __call__(self, dataset, batch_size: int = 4):
        """
        Processing pipeline.

        :param dataset: dataset
        :return:
        """

        # Tokenize inputs
        dataset = dataset.map(lambda row: self.tokenizer(row['text'], padding="max_length", truncation=True),
                              batched=True, remove_columns=['text'], desc='Tokenizing')
        dataset.set_format('torch', columns=['input_ids', 'attention_mask'])
        dataloader = DataLoader(dataset, batch_size=batch_size)

        # Define output classes
        classes = []
        mem_logs = []

        with tqdm(dataloader, unit='batches') as progression:
            progression.set_description('Inference')
            for batch in progression:
                # Forward pass (no gradients needed at inference time)
                with torch.no_grad():
                    outputs = self.model(inputs=batch['input_ids'].to(self.device),
                                         attention_mask=batch['attention_mask'].to(self.device))

                # Move predictions to host memory
                predictions = outputs.logits.cpu().numpy()

                # Map predictions to classes
                classes.extend(_map_outputs(predictions))

                # Track reserved GPU memory (0 on CPU)
                if self.device.type == 'cuda':
                    mem_logs.append(torch.cuda.memory_reserved(self.device) / 1e9)
                else:
                    mem_logs.append(0.0)

                # Update pbar
                progression.set_postfix(memory=f"{sum(mem_logs) / len(mem_logs):.2f}GB")

        return classes


def inputs_to_dataset(inputs: List[str]):
    """
    Convert a list of strings to a dataset object.

    :param inputs: list of strings
    :return: datasets.Dataset with a single 'text' column
    """

    return Dataset.from_dict({'text': list(inputs)})
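

# Usage sketch: assumes a Perceiver classifier fine-tuned for multi-label
# emotion classification and saved whole with torch.save(model, 'model.pt');
# the checkpoint path and example texts are illustrative, not shipped assets.
if __name__ == '__main__':
    pipeline = MultiLabelPipeline(model_path='model.pt')
    texts = ['I love this!', 'This is so frustrating.']
    print(pipeline(inputs_to_dataset(texts), batch_size=2))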