from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import CrossEntropyLoss

from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from torch import nn
import torch.nn.functional as F
from .configuration_aimv2 import MonoConfig
from .modeling_aimv2 import AIMv2Model, PixelShuffleConnector
from transformers.generation import GenerationMixin

"""

Simple architecture of Mono, used to pretrain the vision encoder.

"""


@dataclass
class MonoCausalLMOutputWithPast(ModelOutput):
    """Output container returned by ``MonoForConditionalGeneration.forward``.

    It is returned when ``return_dict`` resolves to true. Field order matters:
    ``ModelOutput`` presumably exposes fields positionally in declaration
    order (e.g. ``to_tuple()`` / integer indexing), so keep the order stable.
    """

    # Next-token cross-entropy loss; only set when ``labels`` were given.
    loss: Optional[torch.FloatTensor] = None
    # Output of ``lm_head``; presumably (batch, seq_len, vocab_size) -- TODO confirm.
    logits: Optional[torch.FloatTensor] = None
    # Decoder cache as returned by ``Qwen2Model``. It may be a Cache object
    # rather than a list depending on the transformers version -- TODO confirm.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # Decoder hidden states, when ``output_hidden_states`` is enabled.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Decoder attention weights, when ``output_attentions`` is enabled.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Slot for image features; not populated by this file's forward().
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class MonoPretrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base class for the Mono models.

    Binds ``MonoConfig`` as the configuration class, names the base-model
    prefix, and declares the feature flags advertised to transformers.
    """

    config_class = MonoConfig
    base_model_prefix = "mono"

    # Feature flags, presumably consulted by the transformers base classes.
    supports_gradient_checkpointing = True
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True


# class MonoForConditionalGeneration(MonoPretrainedModel, Qwen2ForCausalLM):
class MonoForConditionalGeneration(MonoPretrainedModel, GenerationMixin):
    """Causal LM that splices projected image features into a Qwen2 decoder.

    Components built in ``__init__``:

    * ``vision_tower``: ``AIMv2Model`` built from ``config.vision_config``.
    * ``mm_projector``: ``PixelShuffleConnector`` from the vision hidden size
      to the language-model hidden size.
    * ``model``: ``Qwen2Model`` decoder.
    * ``lm_head``: bias-free linear projection to the vocabulary.

    Input-id conventions used throughout this class:

    * ``-200`` marks the single position where a sample's image features are
      inserted.
    * ``-100`` in ``input_ids`` is treated as padding and dropped before the
      tokens are embedded. It is also the label value excluded from the loss
      (``CrossEntropyLoss``'s default ``ignore_index``).

    NOTE(review): ``_tied_weights_keys`` is declared, but this class defines no
    ``get_output_embeddings``, so presumably no weight tying takes place.
    """

    _tied_weights_keys = ["lm_head.weight"]

    # Placeholder id in ``input_ids`` that is replaced by image features.
    _IMAGE_TOKEN_ID = -200
    # Id dropped from ``input_ids`` and used as the "ignore" label.
    _IGNORE_ID = -100

    def __init__(self, config: MonoConfig):
        """Build the vision tower, projector, decoder and LM head from ``config``."""
        # Call the base initialiser directly instead of going through the
        # cooperative ``super()`` chain.
        MonoPretrainedModel.__init__(self, config)

        self.vision_tower = AIMv2Model(config=config.vision_config)
        self._attn_implementation = config._attn_implementation

        self._build_image_projection_layers(config)

        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.pad_token_id = config.pad_token_id
        print(f"==> pad_token_id: {self.pad_token_id}")
        self.post_init()

    def _build_image_projection_layers(self, config):
        """Create ``mm_projector`` mapping the vision hidden size to the text hidden size."""
        image_dim_out = config.vision_config.hidden_size
        dim_projection = config.hidden_size
        self.mm_projector = PixelShuffleConnector(image_dim_out, dim_projection)
        print(f"==> build mm_projector: {image_dim_out} -> {dim_projection}")

    def get_vision_tower(self):
        """Return the image encoder module."""
        return self.vision_tower

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.get_input_embeddings()

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
    ) -> nn.Embedding:
        """Resize the token embeddings and keep ``lm_head`` and the config in sync.

        Args:
            new_num_tokens: target vocabulary size, forwarded to
                ``Qwen2Model.resize_token_embeddings``.
            pad_to_multiple_of: forwarded unchanged.

        Returns:
            The (possibly resized) input embedding module.
        """
        model_embeds = self.model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )
        num_tokens = model_embeds.num_embeddings

        # The decoder only resizes its input embeddings. The output head must
        # follow, or logits keep the old vocabulary size.
        self._resize_lm_head(num_tokens)

        # Not every config carries a nested text config; only update it if present.
        text_config = getattr(self.config, "text_config", None)
        if text_config is not None:
            text_config.vocab_size = num_tokens
        self.config.vocab_size = num_tokens
        self.vocab_size = num_tokens
        return model_embeds

    def _resize_lm_head(self, num_tokens: int) -> None:
        """Replace ``lm_head`` with one producing ``num_tokens`` logits.

        Rows shared with the old head are copied. New rows are drawn from
        N(0, ``config.initializer_range``), falling back to 0.02 when the
        config has no such attribute. This is a no-op when the size already
        matches.
        """
        old_head = self.lm_head
        if old_head.out_features == num_tokens:
            return
        new_head = nn.Linear(
            old_head.in_features,
            num_tokens,
            bias=False,
            device=old_head.weight.device,
            dtype=old_head.weight.dtype,
        )
        std = getattr(self.config, "initializer_range", 0.02)
        num_kept = min(old_head.out_features, num_tokens)
        with torch.no_grad():
            new_head.weight.normal_(mean=0.0, std=std)
            new_head.weight[:num_kept] = old_head.weight[:num_kept]
        new_head.weight.requires_grad_(old_head.weight.requires_grad)
        self.lm_head = new_head

    def _encode_image(self, pixel_values):
        """Encode a batch of images into language-model-sized features.

        Args:
            pixel_values: 4-D image batch for ``vision_tower``; presumably
                (batch, channels, height, width) -- TODO confirm.

        Returns:
            ``mm_projector`` applied to the second-to-last entry of the vision
            tower's ``hidden_states``. ``_get_input_embeds_with_image`` indexes
            the result per sample and concatenates it with token embeddings
            along dim 0.

        Raises:
            ValueError: if ``pixel_values`` is not 4-dimensional.
        """
        if pixel_values.dim() != 4:
            raise ValueError(
                "pixel_values must be a 4-D tensor, got shape "
                f"{tuple(pixel_values.shape)}"
            )
        vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        features = vision_outputs.hidden_states[-2]
        return self.mm_projector(features)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, MonoCausalLMOutputWithPast]:
        """Run the multimodal causal LM.

        Embedding path: if ``inputs_embeds`` is not given and ``input_ids`` is,
        the tokens are embedded by ``_get_input_embeds_with_image``:

        * the image placeholder is replaced by the encoded ``pixel_values``;
        * ``-100`` tokens are dropped;
        * ``attention_mask`` and ``labels`` are rebuilt for the resulting
          sequence. Any caller-supplied ``attention_mask`` is replaced.

        Args:
            labels: optional targets. When given, ``logits`` are upcast to
                float32 and the next-token cross-entropy is computed over
                target positions whose attention-mask entry is non-zero.
            cache_position: forwarded to the decoder only when not ``None``.
            Remaining arguments are passed through to ``Qwen2Model``.

        Returns:
            ``MonoCausalLMOutputWithPast`` when ``return_dict`` resolves to
            true. Otherwise a tuple ``(loss, logits, *decoder_outputs[1:])``,
            where ``loss`` is present only when it was computed.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            image_features = (
                self._encode_image(pixel_values) if pixel_values is not None else None
            )
            if input_ids is not None:
                inputs_embeds, attention_mask, labels = (
                    self._get_input_embeds_with_image(input_ids, image_features, labels)
                )

        extra_model_kwargs = {}
        if cache_position is not None:
            # Only forwarded when supplied (e.g. by generate), so calls that
            # never pass it reach the decoder exactly as before.
            extra_model_kwargs["cache_position"] = cache_position

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **extra_model_kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float to avoid precision issues in the loss. The
            # returned logits are therefore float32 as well.
            logits = logits.float()
            loss = self._lm_loss(logits, labels, attention_mask)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MonoCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @staticmethod
    def _lm_loss(logits, labels, attention_mask):
        """Next-token cross-entropy, restricted to positions kept by ``attention_mask``.

        Logits at position ``i`` are scored against ``labels[:, i + 1]``. With
        a mask, only targets whose mask entry is non-zero contribute. Labels
        equal to -100 are ignored by ``CrossEntropyLoss``'s default.
        """
        labels = labels.to(logits.device)
        if attention_mask is not None:
            # The last (seq_len - 1) mask entries line up with the shifted
            # targets. Cropping from the right also handles masks longer than
            # the logits (e.g. prefix tuning with peft).
            shift_mask = (
                attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) != 0
            )
            shift_logits = logits[..., :-1, :][shift_mask].contiguous()
            shift_labels = labels[..., 1:][shift_mask].contiguous()
        else:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss()
        return loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )

    def _get_input_embeds_with_image(self, input_ids, image_features, labels=None):
        """Embed ``input_ids``, splicing image features in at the placeholder.

        Per sample:

        1. If the row contains the placeholder id ``-200``, the placeholder is
           replaced by ``image_features[idx]``. When ``labels`` are given, the
           matching label entry becomes one ``-100`` per image-feature row, so
           image positions never contribute to the loss. Only one placeholder
           per row is supported: ``.item()`` raises for more.
        2. Tokens equal to ``-100`` are removed before embedding.

        The per-sample sequences are right-padded to the longest one, using
        zero embeddings and ``False`` attention-mask entries.

        Args:
            input_ids: 2-D tensor of token ids (iterated over dim 0).
            image_features: per-sample image features indexed by batch
                position. Only read for rows that contain a placeholder.
            labels: optional labels aligned position-by-position with
                ``input_ids``.

        Returns:
            ``(inputs_embeds, attention_mask, labels)``. ``attention_mask`` is
            boolean. ``labels`` is ``None`` when none were given; otherwise it
            has been padded with ``-100`` and/or truncated to exactly
            ``inputs_embeds.size(1)``.

            NOTE(review): labels are kept positionally (their ``-100`` input
            positions are not removed). They therefore stay aligned only when
            ``-100`` padding in ``input_ids`` is trailing.
        """
        embed_tokens = self.get_input_embeddings()
        embeds_per_sample = []
        labels_per_sample = []

        for idx, seq in enumerate(input_ids):
            placeholder = (seq == self._IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]

            if placeholder.numel() > 0:
                im_pos = placeholder.item()
                before = seq[:im_pos]
                after = seq[im_pos + 1 :]
                seq_embed = torch.cat(
                    [
                        embed_tokens(before[before != self._IGNORE_ID]),
                        image_features[idx],
                        embed_tokens(after[after != self._IGNORE_ID]),
                    ],
                    dim=0,
                )
                if labels is not None:
                    # One ignored label per image-feature row keeps the labels
                    # as long as the expanded sequence.
                    image_ignore = torch.full(
                        (image_features[idx].shape[0],),
                        self._IGNORE_ID,
                        dtype=torch.long,
                        device=labels.device,
                    )
                    labels_per_sample.append(
                        torch.cat(
                            (
                                labels[idx][:im_pos],
                                image_ignore,
                                labels[idx][im_pos + 1 :],
                            ),
                            dim=0,
                        )
                    )
            else:
                seq_embed = embed_tokens(seq[seq != self._IGNORE_ID])
                if labels is not None:
                    # Text-only rows keep their labels so the batch stays aligned.
                    labels_per_sample.append(labels[idx])

            embeds_per_sample.append(seq_embed)

        inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            embeds_per_sample, batch_first=True, padding_value=0.0
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [
                torch.ones(embed.size(0), dtype=torch.bool, device=input_ids.device)
                for embed in embeds_per_sample
            ],
            batch_first=True,
            padding_value=0,
        )

        if labels is None:
            return inputs_embeds, attention_mask, None

        aligned_labels = torch.nn.utils.rnn.pad_sequence(
            labels_per_sample, batch_first=True, padding_value=self._IGNORE_ID
        )
        # Dropping -100 tokens can make the embedded sequence shorter than the
        # labels. Fit the labels to the embedded length so the loss can index
        # logits and labels with the same mask.
        target_len = inputs_embeds.size(1)
        surplus = aligned_labels.size(1) - target_len
        if surplus > 0:
            aligned_labels = aligned_labels[:, :target_len]
        elif surplus < 0:
            aligned_labels = F.pad(aligned_labels, (0, -surplus), value=self._IGNORE_ID)
        return inputs_embeds, attention_mask, aligned_labels

    @torch.no_grad()
    def generate(self, input_ids=None, pixel_values=None, **kwargs):
        """Generate tokens, optionally conditioned on an image.

        The prompt is embedded here and handed to ``GenerationMixin.generate``
        as ``inputs_embeds``.

        * With ``pixel_values``: the image features replace the placeholder
          and the attention mask is rebuilt for the spliced sequence. A
          caller-supplied ``attention_mask`` is discarded because it describes
          the raw token layout.
        * Text-only: a caller-supplied ``attention_mask`` is used if given;
          otherwise an all-ones mask is used.

        Raises:
            ValueError: if ``pixel_values`` is given without ``input_ids``.
        """
        if input_ids is None:
            if pixel_values is not None:
                raise ValueError(
                    "`input_ids` containing the image placeholder are required "
                    "when `pixel_values` is given."
                )
            return super().generate(**kwargs)

        # Pop it so it cannot collide with the mask passed explicitly below.
        caller_attention_mask = kwargs.pop("attention_mask", None)

        if pixel_values is not None:
            image_features = self._encode_image(pixel_values)
            inputs_embeds, attention_mask, _ = self._get_input_embeds_with_image(
                input_ids, image_features
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            if caller_attention_mask is not None:
                attention_mask = caller_attention_mask
            else:
                attention_mask = torch.ones(
                    inputs_embeds.size(0),
                    inputs_embeds.size(1),
                    dtype=torch.bool,
                    device=inputs_embeds.device,
                )

        return super().generate(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        **kwargs,
    ):
        """Delegate to ``GenerationMixin.prepare_inputs_for_generation``.

        ``attention_mask`` is accepted but deliberately not forwarded. Whenever
        ``forward`` embeds ``input_ids`` itself, it rebuilds the mask anyway.
        NOTE(review): as a consequence the mask from ``generate`` presumably
        never reaches the decoder, so batched prompts of different lengths are
        not masked -- confirm before relying on batched generation.
        """
        return super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Shift ``labels`` one position to the right.

        Position 0 is filled with ``pad_token_id``, and any ``-100`` is replaced
        by ``pad_token_id``.

        Raises:
            ValueError: if ``pad_token_id`` is not set.
        """
        if self.pad_token_id is None:
            raise ValueError("`pad_token_id` must be set to shift labels right.")
        shifted = labels.new_full(labels.shape, self.pad_token_id)
        shifted[..., 1:] = labels[..., :-1]
        return shifted.masked_fill(shifted == self._IGNORE_ID, self.pad_token_id)

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the decoder cache along the batch dim to follow ``beam_idx``.

        Cache objects exposing ``reorder_cache`` are reordered in place and
        returned. Legacy nested tuples are rebuilt with ``index_select`` on
        dim 0.
        """
        if hasattr(past_key_values, "reorder_cache"):
            past_key_values.reorder_cache(beam_idx)
            return past_key_values
        return tuple(
            tuple(
                state.index_select(0, beam_idx.to(state.device))
                for state in layer_past
            )
            for layer_past in past_key_values
        )