from enum import Enum
import os
from pathlib import Path
import shutil
import subprocess
from typing import Any, Dict

import ruamel.yaml
import torch

from poetry_diacritizer.models.baseline import BaseLineModel
from poetry_diacritizer.models.cbhg import CBHGModel
from poetry_diacritizer.models.gpt import GPTModel
from poetry_diacritizer.models.seq2seq import Decoder as Seq2SeqDecoder, Encoder as Seq2SeqEncoder, Seq2Seq
from poetry_diacritizer.models.tacotron_based import (
    Decoder as TacotronDecoder,
    Encoder as TacotronEncoder,
    Tacotron,
)

from poetry_diacritizer.options import AttentionType, LossType, OptimizerType
from poetry_diacritizer.util.text_encoders import (
    ArabicEncoderWithStartSymbol,
    BasicArabicEncoder,
    TextEncoder,
)


class ConfigManager:
    """Load a YAML experiment config and build the objects it describes.

    On construction the config file is parsed, the session name and output
    directories are derived from it, the text encoder is instantiated, and a
    few derived entries (symbol counts, enum-typed options) are written back
    into ``self.config``. Models are built on demand via :meth:`get_model`
    or :meth:`load_model`.
    """

    def __init__(self, config_path: str, model_kind: str):
        """
        Args:
            config_path: path to the YAML config file.
            model_kind: one of "baseline", "cbhg", "seq2seq",
                "tacotron_based" or "gpt".

        Raises:
            TypeError: if ``model_kind`` is not a supported kind (kept as
                TypeError rather than ValueError for backward compatibility).
        """
        available_models = ["baseline", "cbhg", "seq2seq", "tacotron_based", "gpt"]
        if model_kind not in available_models:
            raise TypeError(f"model_kind must be in {available_models}")
        self.config_path = Path(config_path)
        self.model_kind = model_kind
        self.yaml = ruamel.yaml.YAML()
        self.config: Dict[str, Any] = self._load_config()
        self.git_hash = self._get_git_hash()
        # "<data_type>.<session_name>.<model_kind>"
        self.session_name = ".".join(
            [self.config["data_type"], self.config["session_name"], model_kind]
        )

        self.data_dir = Path(self.config["data_directory"]) / self.config["data_type"]
        self.base_dir = Path(self.config["log_directory"]) / self.session_name
        self.log_dir = self.base_dir / "logs"
        self.prediction_dir = self.base_dir / "predictions"
        self.plot_dir = self.base_dir / "plots"
        self.models_dir = self.base_dir / "models"
        # Optional path forwarded to the text encoder; None when not configured.
        self.sp_model_path = self.config.get("sp_model_path")
        self.text_encoder: TextEncoder = self.get_text_encoder()
        # Vocabulary sizes are derived from the encoder, not read from the file.
        self.config["len_input_symbols"] = len(self.text_encoder.input_symbols)
        self.config["len_target_symbols"] = len(self.text_encoder.target_symbols)
        # Option names in the YAML file are converted to enum members here;
        # dump_config converts them back to names.
        if self.model_kind in ["seq2seq", "tacotron_based"]:
            self.config["attention_type"] = AttentionType[self.config["attention_type"]]
        self.config["optimizer"] = OptimizerType[self.config["optimizer_type"]]

    def _load_config(self):
        """Parse ``self.config_path`` with ``self.yaml`` and return the mapping."""
        with open(self.config_path, "rb") as model_yaml:
            return self.yaml.load(model_yaml)

    @staticmethod
    def _get_git_hash():
        """Return ``git describe --always`` for the current working directory.

        Best effort: if git is unavailable or the directory is not a
        repository, a warning is printed and ``None`` is returned.
        """
        try:
            return (
                subprocess.check_output(["git", "describe", "--always"])
                .strip()
                .decode()
            )
        except (OSError, subprocess.SubprocessError, UnicodeDecodeError) as e:
            print(f"WARNING: could not retrieve git hash. {e}")
            return None

    def _check_hash(self):
        """Warn when the git hash stored in the config differs from the current one.

        Nothing is compared when the config has no recorded ``git_hash`` or
        when the current hash cannot be determined (``_get_git_hash`` already
        prints its own warning in that case).
        """
        stored_hash = self.config.get("git_hash")
        if stored_hash is None:
            return
        current_hash = self._get_git_hash()
        if current_hash is not None and current_hash != stored_hash:
            print(
                f"WARNING: git hash mismatch. Current: {current_hash}. "
                f"Config hash: {stored_hash}"
            )

    @staticmethod
    def _print_dict_values(values, key_name, level=0, tab_size=2):
        """Print one ``- key : value`` line indented by ``level * tab_size`` spaces."""
        tab = level * tab_size * " "
        print(tab + "-", key_name, ":", values)

    def _print_dictionary(self, dictionary, recursion_level=0):
        """Print ``dictionary``, recursing into nested mappings one level deeper.

        A nested mapping is introduced by a ``- key :`` header line and its
        entries are indented one level below it; sibling keys keep the
        caller's level.
        """
        for key, value in dictionary.items():
            if isinstance(value, dict):
                tab = recursion_level * 2 * " "
                print(tab + "-", key, ":")
                self._print_dictionary(value, recursion_level + 1)
            else:
                self._print_dict_values(value, key_name=key, level=recursion_level)

    def print_config(self):
        """Print the session name followed by every config entry."""
        print("\nCONFIGURATION", self.session_name)
        self._print_dictionary(self.config)

    def update_config(self):
        """Record the current git hash in the config (``None`` if unavailable)."""
        self.config["git_hash"] = self._get_git_hash()

    def dump_config(self):
        """Write the config (with an up-to-date git hash) to ``base_dir/config.yml``.

        Enum members (e.g. ``attention_type``, ``optimizer``) are written by
        name, matching how ``__init__`` looks them up via ``EnumType[name]``.
        """
        self.update_config()
        serializable = {
            key: val.name if isinstance(val, Enum) else val
            for key, val in self.config.items()
        }
        self.base_dir.mkdir(parents=True, exist_ok=True)
        with open(self.base_dir / "config.yml", "w", encoding="utf-8") as model_yaml:
            self.yaml.dump(serializable, model_yaml)

    @staticmethod
    def _confirm_and_remove(prompt, *directories):
        """Ask ``prompt`` on stdin and delete ``directories`` only if the answer is "y"."""
        if input(prompt) == "y":
            for directory in directories:
                shutil.rmtree(directory, ignore_errors=True)

    def create_remove_dirs(
        self,
        clear_dir: bool = False,
        clear_logs: bool = False,
        clear_weights: bool = False,
        clear_all: bool = False,
    ):
        """Create the session directories, optionally clearing some first.

        Every deletion is confirmed interactively; only an answer of "y"
        deletes anything.

        Args:
            clear_dir: offer to delete the logs and models directories.
            clear_logs: offer to delete the logs directory.
            clear_weights: offer to delete the models directory.
            clear_all: offer to delete the logs, models, plots and
                predictions directories.
        """
        self.base_dir.mkdir(exist_ok=True, parents=True)
        if clear_all:
            self._confirm_and_remove(
                f"Delete {self.log_dir}, {self.models_dir}, {self.plot_dir} "
                f"AND {self.prediction_dir}? (y/[n])",
                self.log_dir,
                self.models_dir,
                self.plot_dir,
                self.prediction_dir,
            )
        if clear_dir:
            self._confirm_and_remove(
                f"Delete {self.log_dir} AND {self.models_dir}? (y/[n])",
                self.log_dir,
                self.models_dir,
            )
        if clear_logs:
            self._confirm_and_remove(f"Delete {self.log_dir}? (y/[n])", self.log_dir)
        if clear_weights:
            self._confirm_and_remove(
                f"Delete {self.models_dir}? (y/[n])", self.models_dir
            )
        # (Re)create everything after any deletions above.
        for directory in (
            self.plot_dir,
            self.prediction_dir,
            self.log_dir,
            self.models_dir,
        ):
            directory.mkdir(exist_ok=True)

    def get_last_model_path(self):
        """Return the path of the newest ``<step>-snapshot.pt`` in ``models_dir``.

        Files that do not follow the ``<integer>-snapshot.pt`` naming (other
        ``.pt`` files, notes, ...) are ignored, so the returned path always
        names an existing file.

        Returns:
            The path as a ``str``, or ``None`` if ``models_dir`` is missing or
            contains no snapshots.
        """
        suffix = "-snapshot.pt"
        try:
            file_names = os.listdir(self.models_dir)
        except FileNotFoundError:
            return None
        snapshots = [
            (int(name[: -len(suffix)]), name)
            for name in file_names
            if name.endswith(suffix) and name[: -len(suffix)].isdecimal()
        ]
        if not snapshots:
            return None
        _, latest_name = max(snapshots)
        return os.path.join(self.models_dir, latest_name)

    def load_model(self, model_path: str = None):
        """Build the configured model and restore its weights if a checkpoint exists.

        The model's ``str()`` representation is also written to
        ``base_dir/<model_kind>_network.txt``.

        Args:
            model_path: checkpoint to load; when ``None`` the newest snapshot
                from :meth:`get_last_model_path` is used.

        Returns:
            ``(model, global_step)`` where ``global_step`` is the checkpoint's
            ``global_step + 1``, or ``1`` when no checkpoint was found.
        """
        model = self.get_model()

        self.base_dir.mkdir(parents=True, exist_ok=True)
        network_file = self.base_dir / f"{self.model_kind}_network.txt"
        with open(network_file, "w", encoding="utf-8") as file:
            file.write(str(model))

        if model_path is None:
            model_path = self.get_last_model_path()
            if model_path is None:
                return model, 1

        # The freshly built model lives on the CPU and load_state_dict copies
        # the tensors into it, so mapping to CPU lets GPU-saved checkpoints
        # load on machines without a GPU.
        saved_model = torch.load(model_path, map_location="cpu")
        out = model.load_state_dict(saved_model["model_state_dict"])
        print(out)
        return model, saved_model["global_step"] + 1

    def get_model(self, ignore_hash=False):
        """Build an untrained model of kind ``self.model_kind``.

        Args:
            ignore_hash: skip the git-hash consistency warning.

        Raises:
            ValueError: if ``self.model_kind`` has no builder.
        """
        if not ignore_hash:
            self._check_hash()
        builders = {
            "cbhg": self.get_cbhg,
            "seq2seq": self.get_seq2seq,
            "tacotron_based": self.get_tacotron_based,
            "baseline": self.get_baseline,
            "gpt": self.get_gpt,
        }
        try:
            builder = builders[self.model_kind]
        except KeyError:
            raise ValueError(f"unknown model_kind {self.model_kind!r}") from None
        return builder()

    def get_gpt(self):
        """Build a ``GPTModel`` from the ``base_model_path``/``freeze``/``n_layer``/``use_lstm`` entries."""
        return GPTModel(
            self.config["base_model_path"],
            freeze=self.config["freeze"],
            n_layer=self.config["n_layer"],
            use_lstm=self.config["use_lstm"],
        )

    def get_baseline(self):
        """Build a ``BaseLineModel`` sized by the text encoder's vocabularies."""
        return BaseLineModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            layers_units=self.config["layers_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )

    def get_cbhg(self):
        """Build a ``CBHGModel`` sized by the text encoder's vocabularies."""
        return CBHGModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            use_prenet=self.config["use_prenet"],
            prenet_sizes=self.config["prenet_sizes"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
            post_cbhg_layers_units=self.config["post_cbhg_layers_units"],
            post_cbhg_use_batch_norm=self.config["post_cbhg_use_batch_norm"],
        )

    def _get_tacotron_decoder(self):
        """Build the attention decoder shared by the seq2seq and tacotron_based models."""
        return TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )

    def get_seq2seq(self):
        """Build a seq2seq encoder with the shared attention decoder."""
        encoder = Seq2SeqEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            layers_units=self.config["encoder_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )
        # NOTE(review): this pairs the seq2seq encoder with TacotronDecoder and
        # wraps it in Tacotron although Seq2SeqDecoder/Seq2Seq are imported;
        # confirm this is intentional (their signatures are not visible here).
        return Tacotron(encoder=encoder, decoder=self._get_tacotron_decoder())

    def get_tacotron_based(self):
        """Build a Tacotron-style (CBHG encoder + attention decoder) model."""
        encoder = TacotronEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            prenet_sizes=self.config["prenet_sizes"],
            use_prenet=self.config["use_encoder_prenet"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
        )
        return Tacotron(encoder=encoder, decoder=self._get_tacotron_decoder())

    def get_text_encoder(self):
        """Instantiate the TextEncoder class named by ``config["text_encoder"]``.

        Raises:
            ValueError: if ``text_cleaner`` or ``text_encoder`` is not a
                supported value.
        """
        cleaner = self.config["text_cleaner"]
        if cleaner not in ("basic_cleaners", "valid_arabic_cleaners", None):
            raise ValueError(f"cleaner is not known {cleaner}")

        encoder_classes = {
            "BasicArabicEncoder": BasicArabicEncoder,
            "ArabicEncoderWithStartSymbol": ArabicEncoderWithStartSymbol,
        }
        encoder_name = self.config["text_encoder"]
        if encoder_name not in encoder_classes:
            raise ValueError(f"the text encoder is not found {encoder_name}")
        return encoder_classes[encoder_name](
            cleaner_fn=cleaner, sp_model_path=self.sp_model_path
        )

    def get_loss_type(self):
        """Return the ``LossType`` member named by ``config["loss_type"]``.

        Raises:
            ValueError: if the entry is missing or names no ``LossType`` member.
        """
        loss_name = self.config.get("loss_type")
        try:
            return LossType[loss_name]
        except (KeyError, TypeError):
            raise ValueError(f"The loss type is not correct {loss_name}") from None


if __name__ == "__main__":
    # Manual smoke test: load the Tacotron-style config and print it.
    config_path = "config/tacotron-base-config.yml"
    # "tacotron" is not an accepted model_kind (see available_models in
    # ConfigManager.__init__); the Tacotron-style model kind is "tacotron_based".
    model_kind = "tacotron_based"
    config = ConfigManager(config_path=config_path, model_kind=model_kind)
    config.print_config()