import os

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.utils import ContextManagers

from m4.training.setup_vision_model import vision_model_name_to_model
from m4.training.utils import (
    deepspeed_zero_init_disabled_context_manager,
    is_deepspeed_zero_init_enabled,
    load_state_dict_into_model,
)


# from pathlib import Path


class VLOOMPreTrainedModelBase(PreTrainedModel):
    """
    Base class for VLOOM models that embed a separately loaded vision model.

    It provides alternative constructors (`from_config`, `from_pretrained`, `from_pretrained_models`)
    that build the vision sub-model outside of DeepSpeed's zero.Init context and hand it to the main
    model's constructor through the `vision_model` keyword argument.
    """

    # The problem we are trying to solve is 2 nested zero.Init thanks to fetching from_pretrained(vision_model_name)
    # and then one more zero.Init to override from_pretrained(vision_model_name) once again as it was done in the
    # original - this breaks deepspeed zero3 w/ zero.Init
    # So one solution is this:
    # a. replace from_pretrained(vision_model_name) with from_config(vision_model_name) while hacking to disable
    #    the zero.Init context
    # b. instead of straight replacement of model.vision_model = from_pretrained(vision_model_name) when it gets
    #    updated, we first do from_pretrained(vision_model_name) and then update the existing model with weights
    #    using the already zero.Init'ed pre-sharded weights
    #
    # there are a few variations to get_vision_model_from_config - all need to bypass zero.Init under zero3
    # 1. one variant is to hack into accelerate's deepspeed_plugin and turn off zero.Init while loading the vision model
    # 2. the other variant is to override _from_config method with our version that doesn't do zero.Init

    @classmethod
    def override_vision_model(cls, model, vision_model_name, vision_model_params, torch_dtype):
        """
        Replace the weights of `model.vision_model` with those of the pretrained `vision_model_name`.

        Args:
            model: the already instantiated main model; its `vision_model` attribute is updated in place.
            vision_model_name: name or path passed to `AutoModel.from_pretrained`.
            vision_model_params: dict of extra keyword arguments forwarded to `AutoModel.from_pretrained`.
            torch_dtype: dtype forwarded to `AutoModel.from_pretrained`.
        """
        # 1. fetch the pretrained vision model w/o zero.Init
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model = AutoModel.from_pretrained(vision_model_name, **vision_model_params, torch_dtype=torch_dtype)

        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        real_vision_model = vision_model_name_to_model(vision_model_name, vision_model)

        # 2. now override the weights already sharded by zero.Init with the weights from the real_vision_model
        # by gradually gathering sharded weights and replacing with new weights
        if is_deepspeed_zero_init_enabled():
            state_dict = real_vision_model.state_dict()
            load_state_dict_into_model(model.vision_model, state_dict, start_prefix="")
        else:
            model.vision_model = real_vision_model

    @classmethod
    def _create_uninitialized_vision_model(cls, vision_model_name, vision_model_params, torch_dtype):
        """Build the vision model from its config only (no pretrained weights), bypassing zero.Init."""
        # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
            vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        return vision_model_name_to_model(vision_model_name, vision_model_from_config)

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        Instantiate the model from `config` with an uninitialized vision model.

        Args:
            config: model config exposing `vision_model_name` and `vision_model_params`
                (the latter being a string holding a Python expression that evaluates to a dict).
            **kwargs: forwarded to the class constructor; `torch_dtype`, if present, is also
                used to create the vision model.

        Returns:
            The newly constructed model.
        """
        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        vision_model_name = config.vision_model_name
        # SECURITY NOTE(review): eval() executes whatever expression the config string contains -
        # only use configs from trusted sources.
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        kwargs["vision_model"] = cls._create_uninitialized_vision_model(
            vision_model_name, vision_model_params, torch_dtype
        )

        # 2. create the main class's model, passing the uninitialized vision_model to it
        model = cls(config, **kwargs)

        return model

    @classmethod
    def from_pretrained_models(cls, *args, **kwargs):
        """
        Use this method when creating a new vloom model that hasn't been yet trained and it'll be
        composed of 2 pre-trained models - hence `pretrained_models`.
        """

        return cls.from_pretrained(*args, **kwargs, new_model=True)

    @classmethod
    def from_pretrained(cls, *model_args, is_resume=False, new_model=False, **kwargs):
        """
        Use this method when loading an already pretrained vloom model - either from a checkpoint or from hub.
        For creating an untrained model use `from_pretrained_models` instead.

        Args:
            *model_args: positional arguments forwarded to the parent `from_pretrained`.
            is_resume: if True (and `new_model` is False), an empty model is created and the weights
                are expected to be loaded later from the checkpoint.
            new_model: if True, the vision model weights are overridden with pretrained ones after
                the model has been created. Takes precedence over `is_resume`.
            **kwargs: forwarded to the parent `from_pretrained`. `config` may be omitted, a
                `PretrainedConfig`, or a path (`str` or `os.PathLike`) to a config.

        Returns:
            The instantiated model.

        Raises:
            TypeError: if `config` is neither None, a `PretrainedConfig`, nor a path.
        """

        is_untrained_vloom_model = False
        is_pretrained_vloom_model_resumed = False
        is_pretrained_vloom_model_from_hub_or_path = False

        # we have 3 use cases:
        # 1. is_untrained_vloom_model - a totally new vloom model
        # 2. is_pretrained_vloom_model_resumed - a pretrained vloom model being resumed from a
        #    checkpoint (instantiate a random empty model in this case)
        # 3. is_pretrained_vloom_model_from_hub_or_path - a pretrained vloom model loaded from hub or local path
        if new_model:
            is_untrained_vloom_model = True
        elif is_resume:
            is_pretrained_vloom_model_resumed = True
        else:
            is_pretrained_vloom_model_from_hub_or_path = True

        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        # config is:
        # 1. either not passed and then we use the model's default config (used by tests)
        # 2. passed and in which case it's one of:
        #   2a. `PretrainedConfig` (a new m4 model)
        #   2b. path to a json config (an already pretrained m4 model, usually resumed training)
        config = kwargs.get("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(*model_args, **kwargs, return_unused_kwargs=False)
        elif not isinstance(config, PretrainedConfig):
            # adapted from https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/modeling_utils.py#L1920
            # Validate explicitly rather than with `assert` (which is stripped under `python -O`),
            # and accept plain string paths in addition to os.PathLike objects.
            if not isinstance(config, (str, os.PathLike)):
                raise TypeError(
                    "`config` must be a `PretrainedConfig`, a `str` or an `os.PathLike`, "
                    f"got {type(config).__name__}"
                )
            config_path = str(config)
            config = cls.config_class.from_pretrained(
                config_path,
                return_unused_kwargs=False,
                **kwargs,
            )

        vision_model_name = config.vision_model_name
        # SECURITY NOTE(review): eval() executes whatever expression the config string contains -
        # only use configs from trusted sources.
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        kwargs["vision_model"] = cls._create_uninitialized_vision_model(
            vision_model_name, vision_model_params, torch_dtype
        )

        # 2. create the vloom model
        if is_untrained_vloom_model or is_pretrained_vloom_model_from_hub_or_path:
            model = super().from_pretrained(*model_args, **kwargs)
        elif is_pretrained_vloom_model_resumed:
            # in the case of resume under deepspeed we create an empty model, and get deepspeed
            # to load the weights from the checkpoint
            # but not all models have these keys so handle the case they don't have them
            _ = kwargs.pop("config", None)
            model = super().from_pretrained(None, config=config, state_dict={}, **kwargs)

        # 3. if is_untrained_vloom_model, now override the uninitialized vision_model with one with pretrained weights
        if is_untrained_vloom_model:
            # NOTE(review): `override_vision_model_wrapper` is not defined in this base class;
            # presumably each concrete subclass provides it - verify.
            cls.override_vision_model_wrapper(model, config, vision_model_name, vision_model_params, torch_dtype)

        return model


class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained.
    If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze=False,
        device=None,
        dtype=None,
        padding_idx=None,
        **kwargs,
    ) -> None:
        """
        num_embeddings: int. Size of the regular (possibly frozen) embedding table.
        num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`.
        embedding_dim: int. Size of each embedding vector (shared by both tables).
        partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen.
        device, dtype: forwarded to both the regular and the additional embedding.
        padding_idx: forwarded to `nn.Embedding.__init__`. Must be strictly smaller than `num_embeddings`.

        Raises `ValueError` if `padding_idx >= num_embeddings`.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these: they are passed on to the parent constructor, but `forward` calls `F.embedding` with the ids and the weight only.
        """
        # valid indices are 0 .. num_embeddings - 1, so `padding_idx == num_embeddings` is already out of range
        if padding_idx is not None and padding_idx >= num_embeddings:
            raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=padding_idx,
            **kwargs,
        )
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)

        # the additional table only exists when requested; it is left trainable
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )

    def forward(self, input_ids):
        """
        we have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.

        in order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings),
           since the 2nd embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        note: for the 1st embedding lookup we could have looked up only the low indices and not do
        the padding, but then we have to create a new tensor and populate it with 2 tensors that are
        spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
        complex case if it's any faster, given that seqlens are usually relatively short it's
        probably not faster or if faster not by much - but might be a good idea to measure.

        """
        # no additional table: plain lookup in the regular weight
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        # shift the high ids so that they index the additional table from 0
        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        """Overwriting `nn.Embedding.extra_repr` to include the new parameters."""
        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.num_embeddings,
            self.num_additional_embeddings,
            self.embedding_dim,
            self.partially_freeze,
        )

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, **kwargs):
        """Not supported for the decoupled embedding: always raises `NotImplementedError`."""
        raise NotImplementedError


class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Linear layer whose output features are split into a regular part and an optional additional part.

    The regular `weight` (and `bias`) can be frozen with `partially_freeze=True`. When
    `out_additional_features` > 0, a second linear layer `additional_fc` with
    `out_additional_features * in_features` weights is created and its output is concatenated
    after the regular output. With `out_additional_features=0` the module behaves like `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        in_features: int. Size of each input sample.
        out_features: int. Number of regular output features.
        out_additional_features: int. Number of extra output features produced by `additional_fc`.
        bias: bool. Whether the regular layer and the additional layer have a bias.
        partially_freeze: bool. If True, `weight` (and `bias`, if any) get `requires_grad=False`.
        device, dtype: forwarded to both the regular and the additional layer.
        """
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

        self.in_features = in_features
        self.out_features = out_features
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        if partially_freeze:
            # freeze the regular parameters only; `additional_fc` is created afterwards
            to_freeze = [self.weight]
            if bias:
                to_freeze.append(self.bias)
            for parameter in to_freeze:
                parameter.requires_grad_(False)

        if out_additional_features > 0:
            self.additional_fc = nn.Linear(
                in_features,
                out_additional_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the regular projection and, if present, append the additional features on the last dim."""
        regular_out = F.linear(input, self.weight, self.bias)

        if self.out_additional_features > 0:
            extra_layer = self.additional_fc
            extra_out = F.linear(input, extra_layer.weight, extra_layer.bias)
            return torch.cat((regular_out, extra_out), -1)

        return regular_out

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        has_bias = self.bias is not None
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"out_additional_features={self.out_additional_features}, "
            f"bias={has_bias}, "
            f"partially_freeze={self.partially_freeze}"
        )


if __name__ == "__main__":
    # Manual smoke test: build both decoupled modules, run one forward/backward pass and print
    # which parameters require grad and which gradients were populated.
    emb = DecoupledEmbedding(num_embeddings=10, num_additional_embeddings=3, embedding_dim=5, partially_freeze=True)
    for n, p in emb.named_parameters():
        print(n, p.requires_grad)
    # 11 falls into the additional embedding (>= num_embeddings), 1 and 3 into the regular one
    idx = torch.tensor([[11, 1, 3]])
    y = emb(idx)
    loss = y.sum()
    loss.backward()
    print(emb.weight, emb.weight.grad)
    # `additional_embedding` is an nn.Embedding module, which has no `.grad` attribute of its own:
    # the gradient lives on its `weight` parameter
    print(emb.additional_embedding.weight, emb.additional_embedding.weight.grad)

    lin = DecoupledLinear(in_features=3, out_features=4, out_additional_features=2, bias=True, partially_freeze=True)
    for n, p in lin.named_parameters():
        print(n, p.requires_grad)
    x = torch.randn(12, 3)
    y = lin(x)
    loss = y.sum()
    loss.backward()
    print("Weight w and grad:", lin.weight, lin.weight.grad)
    print("bias w and grad:", lin.bias, lin.bias.grad)
    print("additional_fc.weight w and grad:", lin.additional_fc.weight, lin.additional_fc.weight.grad)
    print("additional_bias w and grad:", lin.additional_fc.bias, lin.additional_fc.bias.grad)