from typing import List, Optional, Tuple

import torch
from PIL import Image
from torch import nn
from transformers import (
    CLIPVisionModel,
    GenerationMixin,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from .catty import split_image_with_catty
from .configuration_points_chat import POINTSChatConfig
from .dynamic_high_resolution import split_image
from .modeling_llama import CustomLlamaForCausalLM


class POINTSChatModel(PreTrainedModel, GenerationMixin):
    """Chat model for POINTS.

    Official implementation of the paper "POINTS: Improving Your Vision-language Model with Affordable Strategies"  # noqa: E501
    paper: https://huggingface.co/papers/2409.04828

    Each image crop is encoded by two CLIP vision towers (``general_vit``
    and ``ocr_vit``). Both outputs pass through the shared
    ``vision_projector``, and ``chat`` averages them. The resulting features
    replace the embeddings of the image placeholder tokens before the LLM
    generates.

    Args:
        config (POINTSChatConfig): The model config.
    """

    config_class = POINTSChatConfig
    # Module class names that transformers must not split across devices.
    # NOTE(review): confirm "LLamaDecoderLayer" matches the decoder-layer
    # class name in modeling_llama (the usual spelling is "LlamaDecoderLayer").
    _no_split_modules = ["CLIPVisionModel", "LLamaDecoderLayer"]
    # Placeholder tokens (and projected feature vectors) per image crop:
    # extract_image_features turns a 24x24 patch grid into 12x12 = 144
    # merged patches via pixel_shuffle.
    _TOKENS_PER_IMAGE = 144

    def __init__(self, config: POINTSChatConfig) -> None:
        super().__init__(config)
        # Two vision towers built from the same vision config; their
        # projected outputs are averaged in ``chat``.
        self.general_vit = CLIPVisionModel(config.vision_config)
        self.ocr_vit = CLIPVisionModel(config.vision_config)
        self.llm = CustomLlamaForCausalLM(config.llm_config)
        # Input width is 4x the ViT width because pixel_shuffle concatenates
        # 2x2 neighbouring patches along the channel dimension.
        self.vision_projector = nn.Sequential(
            nn.Linear(config.vision_config.hidden_size * 4,
                      config.llm_config.hidden_size),
            nn.GELU(),
            nn.Linear(config.llm_config.hidden_size,
                      config.llm_config.hidden_size),
        )

    def apply_chat_template(self, prompt: str, image_num: int) -> str:
        """Apply the Yi-1.5-Chat template to the prompt.

        The image placeholders (``_TOKENS_PER_IMAGE`` copies of
        ``<|endoftext|>`` per image) are placed before the user text.

        Args:
            prompt (str): The prompt to apply the template to.
            image_num (int): The number of images in the prompt.
        Returns:
            str: The prompt with the template applied.
        """
        image_tokens = '<|endoftext|>' * (self._TOKENS_PER_IMAGE * image_num)
        prompt = f'<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n<|im_start|>assistant\n'  # noqa: E501
        return prompt

    def pixel_shuffle(self, feature_map: torch.Tensor,
                      scale_factor: float = 0.5) -> torch.Tensor:
        """Implementation of pixel shuffle.

        Merge several patches into a single patch by concatenating
        them across the channel dimension. Therefore, we can reduce
        the image sequence length. In POINTS, we merge 2x2 adjacent
        patches into a single patch.

        Args:
            feature_map (torch.Tensor): The feature map to be pixel
                shuffled, shape (N, W, H, C).
            scale_factor (float, optional): Factor applied to each spatial
                dimension; the channel dimension is divided by
                ``scale_factor ** 2``. Defaults to 0.5 (2x2 merge).

        Returns:
            torch.Tensor: Shape (N, W * scale_factor, H * scale_factor,
            C / scale_factor ** 2).
        """
        # taken from https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5/blob/main/modeling_internvl_chat.py#L187 # noqa
        n, w, h, c = feature_map.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        feature_map = feature_map.view(
            n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        feature_map = feature_map.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        feature_map = feature_map.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        # Back to N, W * scale, H * scale, C // (scale ** 2)
        feature_map = feature_map.permute(0, 2, 1, 3).contiguous()
        return feature_map

    def extract_image_features(self, images: torch.Tensor,
                               vision_encoder: str = 'general_vit') -> torch.Tensor:  # noqa: E501
        """Extract the image features from the vision encoder.

        Args:
            images (torch.Tensor): The preprocessed images to encode.
            vision_encoder (str, optional): ``'general_vit'`` selects the
                general tower; any other value selects ``ocr_vit``.
                Defaults to 'general_vit'.

        Returns:
            torch.Tensor: Projected features of shape
            (num_images, _TOKENS_PER_IMAGE, llm hidden size).
        """
        if vision_encoder == 'general_vit':
            vit = self.general_vit
        else:
            vit = self.ocr_vit
        outputs = vit(images, output_hidden_states=True)
        # Penultimate hidden layer, with the first (class) token dropped.
        image_features = outputs.hidden_states[-2][:, 1:]
        # The reshape assumes a 24x24 patch grid with 1024-dim features.
        image_features = image_features.reshape(-1, 24, 24, 1024)
        image_features = self.pixel_shuffle(image_features, 0.5)
        image_features = image_features.view(
            -1, self._TOKENS_PER_IMAGE, 4096)
        image_features = self.vision_projector(image_features)
        return image_features

    def get_pos_mapping(self, pos: List[list]) -> Tuple[dict, int]:
        """Get the position mapping for the images.

        Args:
            pos (List[list]): ``[start, end)`` spans of image placeholder
                tokens in the prompt.

        Returns:
            Tuple[dict, int]: Maps span index -> number of images in that
            span (span length // _TOKENS_PER_IMAGE), plus the total number
            of images.
        """
        mapping = {}
        total_images = 0
        for i, (start, end) in enumerate(pos):
            num_image = (end - start) // self._TOKENS_PER_IMAGE
            mapping[i] = num_image
            total_images += num_image
        return mapping, total_images

    @torch.no_grad()
    def chat(self, pixel_values: Image.Image, prompt: str,
             tokenizer: PreTrainedTokenizer,
             image_processor, catty: bool = True,
             generation_config: Optional[dict] = None,
             max_splits: int = 8) -> str:
        """Generate a response to the input prompt.

        Args:
            pixel_values (Image.Image): The input image.
            prompt (str): The input prompt.
            tokenizer (PreTrainedTokenizer): The tokenizer to use.
            image_processor: The image processor to use.
            catty (bool, optional): Whether to use catty. Defaults to True.
            generation_config (dict, optional): Extra keyword arguments for
                generation. It is not modified; ``eos_token_id`` is always
                set to the id of ``<|im_end|>``. Defaults to None (no extra
                arguments).
            max_splits (int, optional): The maximum number of splits.
                Defaults to 8.
        Returns:
            str: The generated response.
        """
        if catty:
            cropped_images = split_image_with_catty(pixel_values,
                                                    do_resize=True,
                                                    max_crop_slices=max_splits)
        else:
            cropped_images = split_image(pixel_values, max_splits=max_splits)
        prompt = self.apply_chat_template(prompt, len(cropped_images))
        cropped_images = image_processor.preprocess(
            cropped_images, return_tensors='pt')['pixel_values']
        cropped_images = cropped_images.to(self.device)
        # extract features with general_vit
        general_vit_features = self.extract_image_features(
            cropped_images, vision_encoder='general_vit')
        # extract features with ocr_vit
        ocr_vit_features = self.extract_image_features(
            cropped_images, vision_encoder='ocr_vit')
        # Equal-weight average of the two vision towers.
        image_features = 0.5 * general_vit_features + 0.5 * ocr_vit_features
        model_inputs = tokenizer(prompt, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        # stop token
        eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        # image token
        image_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
        # Build a fresh dict: the default None must be accepted, and the
        # caller's dict must not be mutated across calls.
        generation_kwargs = {**(generation_config or {}),
                             'eos_token_id': eos_token_id}
        outputs = self.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_features=[image_features],
            image_token_id=image_token_id,
            **generation_kwargs,
        )
        response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return response

    def generate(self,
                 input_ids: torch.LongTensor,
                 attention_mask: torch.LongTensor,
                 image_features: List[torch.Tensor],
                 image_token_id: int,
                 generation_config: Optional[dict] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 **generate_kwargs) -> torch.LongTensor:
        """Generate with image features spliced into the prompt embeddings.

        Each contiguous run of ``image_token_id`` in ``input_ids[i]`` is
        overwritten, in the input embeddings, by the features of
        ``run_length // _TOKENS_PER_IMAGE`` consecutive images from
        ``image_features[i]``, consumed in order.

        Args:
            input_ids: Prompt token ids, indexed as (batch, seq_len).
            attention_mask: Attention mask passed through to the LLM.
            image_features: One entry per batch element; entry ``i`` holds
                the per-image features for sample ``i`` (presumably each of
                shape (_TOKENS_PER_IMAGE, hidden), as produced by
                ``extract_image_features``).
            image_token_id: Token id of the image placeholder.
            generation_config: Forwarded to ``self.llm.generate``.
            output_hidden_states: Forwarded to ``self.llm.generate``.
            return_dict: Forwarded to ``self.llm.generate``.
            **generate_kwargs: Forwarded to ``self.llm.generate``.

        Returns:
            The output of ``self.llm.generate``.

        Raises:
            ValueError: If ``image_features`` does not have one entry per
                batch element, if the number of images in a sample does not
                match its placeholder runs, or if a run length is not a
                multiple of ``_TOKENS_PER_IMAGE``.
        """
        input_embeddings = self.llm.lm.embed_in(input_ids)
        batch_size = input_ids.shape[0]
        if len(image_features) != batch_size:
            raise ValueError(
                f'Expected {batch_size} image feature entries (one per '
                f'batch element), got {len(image_features)}.')
        for i in range(batch_size):
            special_pos = input_ids[i] == image_token_id
            # Pad the mask with False on both ends so that every run of image
            # tokens yields exactly one rising and one falling edge, even when
            # the run touches the start or the end of the sequence. Each
            # consecutive pair of edges is then a half-open [start, end) span.
            pad = special_pos.new_zeros(1)
            padded = torch.cat([pad, special_pos, pad])
            edges = (padded[1:] != padded[:-1]).nonzero().flatten()
            pos = edges.reshape(-1, 2).tolist()
            pos_mapping, total_images = self.get_pos_mapping(pos)
            if total_images != len(image_features[i]):
                raise ValueError(
                    f'Sample {i}: prompt has placeholders for {total_images} '
                    f'images but {len(image_features[i])} image features '
                    f'were given.')
            img_offset = 0
            for j, (start, end) in enumerate(pos):
                num_images = pos_mapping[j]
                if end - start != num_images * self._TOKENS_PER_IMAGE:
                    raise ValueError(
                        f'Sample {i}: image token run [{start}, {end}) has '
                        f'length {end - start}, which is not a positive '
                        f'multiple of {self._TOKENS_PER_IMAGE}.')
                input_embeddings[i, start:end] = torch.cat(
                    [image_features[i][img_offset + k]
                     for k in range(num_images)],
                    dim=0,
                )
                img_offset += num_images
        outputs = self.llm.generate(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs
        )
        return outputs