"""
Loading the diacritization dataset
"""

import os

from diacritization_evaluation import util
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

from .config_manager import ConfigManager

BASIC_HARAQAT = {
    "َ": "Fatha              ",
    "ً": "Fathatah           ",
    "ُ": "Damma              ",
    "ٌ": "Dammatan           ",
    "ِ": "Kasra              ",
    "ٍ": "Kasratan           ",
    "ْ": "Sukun              ",
    "ّ": "Shaddah            ",
}


class DiacritizationDataset(Dataset):
    """
    The diacritization dataset
    """

    def __init__(self, config_manager: ConfigManager, list_ids, data):
        "Initialization"
        self.list_ids = list_ids
        self.data = data
        self.text_encoder = config_manager.text_encoder
        self.config = config_manager.config

    def __len__(self):
        "Denotes the total number of samples"
        return len(self.list_ids)

    def preprocess(self, book):
        out = ""
        i = 0
        while i < len(book):
            if i < len(book) - 1:
                if book[i] in BASIC_HARAQAT and book[i + 1] in BASIC_HARAQAT:
                    i += 1
                    continue
            out += book[i]
            i += 1
        return out

    def __getitem__(self, index):
        "Generates one sample of data"
        # Select sample
        id = self.list_ids[index]
        if self.config["is_data_preprocessed"]:
            data = self.data.iloc[id]
            inputs = torch.Tensor(self.text_encoder.input_to_sequence(data[1]))
            targets = torch.Tensor(
                self.text_encoder.target_to_sequence(
                    data[2].split(self.config["diacritics_separator"])
                )
            )
            return inputs, targets, data[0]

        data = self.data[id]
        non_cleaned = data

        data = self.text_encoder.clean(data)
        data = data[: self.config["max_sen_len"]]
        text, inputs, diacritics = util.extract_haraqat(data)

        inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
        diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))

        return inputs, diacritics, text


def collate_fn(data):
    """
    Padding the input and output sequences
    """

    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    data.sort(key=lambda x: len(x[0]), reverse=True)

    # separate source and target sequences
    src_seqs, trg_seqs, original = zip(*data)

    # merge sequences (from tuple of 1D tensor to 2D tensor)
    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, trg_lengths = merge(trg_seqs)

    batch = {
        "original": original,
        "src": src_seqs,
        "target": trg_seqs,
        "lengths": torch.LongTensor(src_lengths),  # src_lengths = trg_lengths
    }
    return batch


def load_training_data(config_manager: ConfigManager, loader_parameters):
    """
    Loading the training data using pandas
    """

    if not config_manager.config["load_training_data"]:
        return []

    path = os.path.join(config_manager.data_dir, "train.csv")
    if config_manager.config["is_data_preprocessed"]:
        train_data = pd.read_csv(
            path,
            encoding="utf-8",
            sep=config_manager.config["data_separator"],
            nrows=config_manager.config["n_training_examples"],
            header=None,
        )

        # train_data = train_data[train_data[0] <= config_manager.config["max_len"]]
        training_set = DiacritizationDataset(
            config_manager, train_data.index, train_data
        )
    else:
        with open(path, encoding="utf8") as file:
            train_data = file.readlines()
            train_data = [
                text
                for text in train_data
                if len(text) <= config_manager.config["max_len"] and len(text) > 0
            ]
        training_set = DiacritizationDataset(
            config_manager, [idx for idx in range(len(train_data))], train_data
        )

    train_iterator = DataLoader(
        training_set, collate_fn=collate_fn, **loader_parameters
    )

    print(f"Length of training iterator = {len(train_iterator)}")
    return train_iterator


def load_test_data(config_manager: ConfigManager, loader_parameters):
    """
    Loading the test data using pandas
    """
    if not config_manager.config["load_test_data"]:
        return []
    test_file_name = config_manager.config.get("test_file_name", "test.csv")
    path = os.path.join(config_manager.data_dir, test_file_name)
    if config_manager.config["is_data_preprocessed"]:
        test_data = pd.read_csv(
            path,
            encoding="utf-8",
            sep=config_manager.config["data_separator"],
            nrows=config_manager.config["n_test_examples"],
            header=None,
        )
        # test_data = test_data[test_data[0] <= config_manager.config["max_len"]]
        test_dataset = DiacritizationDataset(config_manager, test_data.index, test_data)
    else:
        with open(path, encoding="utf8") as file:
            test_data = file.readlines()
        max_len = config_manager.config["max_len"]
        test_data = [text[:max_len] for text in test_data]
        test_dataset = DiacritizationDataset(
            config_manager, [idx for idx in range(len(test_data))], test_data
        )

    test_iterator = DataLoader(test_dataset, collate_fn=collate_fn, **loader_parameters)

    print(f"Length of test iterator = {len(test_iterator)}")
    return test_iterator


def load_validation_data(config_manager: ConfigManager, loader_parameters):
    """
    Loading the validation data using pandas
    """

    if not config_manager.config["load_validation_data"]:
        return []
    path = os.path.join(config_manager.data_dir, "eval.csv")
    if config_manager.config["is_data_preprocessed"]:
        valid_data = pd.read_csv(
            path,
            encoding="utf-8",
            sep=config_manager.config["data_separator"],
            nrows=config_manager.config["n_validation_examples"],
            header=None,
        )
        valid_data = valid_data[valid_data[0] <= config_manager.config["max_len"]]
        valid_dataset = DiacritizationDataset(
            config_manager, valid_data.index, valid_data
        )
    else:
        with open(path, encoding="utf8") as file:
            valid_data = file.readlines()

        max_len = config_manager.config["max_len"]
        valid_data = [text[:max_len] for text in valid_data]
        valid_dataset = DiacritizationDataset(
            config_manager, [idx for idx in range(len(valid_data))], valid_data
        )

    valid_iterator = DataLoader(
        valid_dataset, collate_fn=collate_fn, **loader_parameters
    )

    print(f"Length of valid iterator = {len(valid_iterator)}")
    return valid_iterator


def load_iterators(config_manager: ConfigManager):
    """
    Load the data iterators
    Args:
    """
    params = {
        "batch_size": config_manager.config["batch_size"],
        "shuffle": True,
        "num_workers": 2,
    }
    train_iterator = load_training_data(config_manager, loader_parameters=params)
    valid_iterator = load_validation_data(config_manager, loader_parameters=params)
    test_iterator = load_test_data(config_manager, loader_parameters=params)
    return train_iterator, test_iterator, valid_iterator