import copy
import os
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)

import contextlib

import torch.utils.checkpoint
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

from .modeling_perceive_sampler import BertConfig, BertLMHeadModel
from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *

from transformers.utils import logging
logger = logging.get_logger(__name__)


class InternLMXComposerForCausalLM(PreTrainedModel):
    config_class = InternLMXComposerConfig
    _auto_class = "AutoModelForCausalLM"

    gen_config = dict(
        num_beams=5,
        do_sample=False,
        min_length=1,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1.0,
        max_new_tokens=200,
    )

    def __init__(self, config):
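        """Assemble the three components: the EVA ViT-g vision encoder, the
        Q-Former perceive sampler, and the InternLM language model together
        with its LoRA configuration."""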
        super().__init__(config)

        print('Init ViT ... ', end='')
        self.visual_encoder = create_eva_vit_g()
        self.ln_vision = LayerNorm(self.visual_encoder.num_features)
        print('Done')

        print('Init Perceive Sampler ... ', end='')
        with all_logging_disabled():
            self.Qformer, self.query_tokens = self.init_qformer(
                config.num_query_token, self.visual_encoder.num_features)
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.Qformer.cls = None
        print('Done')

        print('Init InternLM ... ', end='')
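        # Frozen learned embeddings that mark where image tokens start and end
        # in the language-model input sequence.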
        self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_start.requires_grad = False
        self.flag_image_end.requires_grad = False

        internlm_lora = config.internlm_lora
        self.internlm_lora = internlm_lora
        setattr(InternLMForCausalLM, 'lora_cfg', internlm_lora)

        # torch 1.x: materialize weights directly; torch 2.x supports the
        # device context manager, so build on the meta device first to speed
        # up initialization.
        torch_major = int(torch.__version__.split('.')[0])
        if torch_major == 1:
            self.internlm_model = InternLMForCausalLM._from_config(config).to(
                torch.float16)
        else:
            assert torch_major == 2
            # speed up init llm
            with torch.device('meta'):
                self.internlm_model = InternLMForCausalLM._from_config(config)
            self.internlm_model.to_empty(device='cpu').to(torch.float16)
            self.internlm_model.to(config.device)
        # Cast LoRA adapter weights to fp32; the base language model stays fp16.
        for n, m in self.internlm_model.named_modules():
            if 'lora' in n:
                m.float()

        self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
                                       self.internlm_model.config.hidden_size)
        print('Done')

        self.vis_processor = transforms.Compose([
            transforms.Resize((224, 224),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

        self.tokenizer = None  # set externally before calling generate/chat

    @property
    def eoh(self):
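        """End-of-human-turn marker inserted after the user text."""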
        #return self.tokenizer.decode(torch.Tensor([103027]),
        #                             skip_special_tokens=True)
        return '<TOKENS_UNUSED_0>'

    @property
    def eoa(self):
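        """End-of-answer marker used to truncate generated responses."""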
        #return self.tokenizer.decode(torch.Tensor([103028]),
        #                             skip_special_tokens=True)
        return '<TOKENS_UNUSED_1>'

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_qformer(cls,
                     num_query_token,
                     vision_width,
                     cross_attention_freq=2,
                     pretrain=True):
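        """Build the Q-Former: a BERT encoder with cross-attention inserted
        every `cross_attention_freq` layers, plus learnable query tokens."""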
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        # if pretrain:
        #     Qformer = BertLMHeadModel.from_pretrained("bert-base-uncased",
        #                                               config=encoder_config)
        # else:
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0,
                                  std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def encode_img(self, image):
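        """Encode an image (file path or preprocessed tensor) into InternLM
        input embeddings, bracketed by the image start/end flag embeddings."""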
        if image is None:
            return None
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        else:
            assert isinstance(image, torch.Tensor)
        device = image.device
        with self.maybe_autocast():
            image_embeds = self.ln_vision(
                self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1],
                                    dtype=torch.long).to(device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1,
                                                    -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            inputs_internlm = self.internlm_proj(
                query_output.last_hidden_state)
            inputs_internlm = torch.cat([
                self.flag_image_start.expand(inputs_internlm.shape[0], -1, -1),
                inputs_internlm,
                self.flag_image_end.expand(inputs_internlm.shape[0], -1, -1),
            ], dim=1)
        return inputs_internlm

    def encode_text(self, text, add_special_tokens=False):
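        """Tokenize `text` and return its InternLM token embeddings."""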
        text_token_ids = self.tokenizer(
            text,
            return_tensors='pt',
            add_special_tokens=add_special_tokens,
        ).input_ids.to(self.device)
        text_embeds = self.internlm_model.model.embed_tokens(text_token_ids)
        return text_embeds

    def decode_text(self, out_embeds):
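        """Decode generated ids to text, truncated at the eoa marker."""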
        out_text = self.tokenizer.batch_decode(out_embeds,
                                               skip_special_tokens=True)[0]
        out_text = out_text.split(self.eoa)[0]
        return out_text

    def wrap_text(self, user_text, bot_text='', add_special=True):
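        """Wrap user (and optional bot) text in the chat prompt template."""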
        if add_special:
            eoh = self.eoh
        else:
            eoh = ''
        text = f' <|User|>:{user_text} \n{eoh} <|Bot|>:{bot_text}'
        return text

    def get_gen_args(self, **kwargs):
        """Merge the default generation config with caller overrides."""
        new_kwargs = copy.deepcopy(self.gen_config)
        new_kwargs.update(kwargs)
        return new_kwargs

    def forward(self, **kwargs):
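        # Plain forward pass: delegate directly to the wrapped language model.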
        return self.internlm_model(**kwargs)

    def generate(self, text, image=None, **kwargs):
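        """Generate a response for a single user query with an optional image."""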
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)
        return out_text

    def chat(self, text, image=None, history=None, **kwargs):
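        """Run one chat turn: generate a reply and return it along with the
        updated `history` (a list of prompt-plus-response embedding tensors)."""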
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         img_embeds,
                                         history=history)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)

        # Re-embed the cleaned response (decode_text already truncated it at
        # eoa) so this turn can be appended to the conversation history.
        clean_out_text_token_ids = self.tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        clean_out_text_embeds = self.internlm_model.model.embed_tokens(
            clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               img_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def wrap_prompt(self,
                    text_embeds,
                    img_embeds=None,
                    history=None,
                    add_special=True):
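        """Assemble the prompt embeddings: ' <|User|>:' prefix, image
        embeddings, text embeddings, then the eoh marker and ' <|Bot|>:'
        segment, with any `history` embeddings prepended."""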
        if add_special:
            prompt_segs = [' <|User|>:', f'\n{self.eoh} <|Bot|>:']
        else:
            prompt_segs = [' <|User|>:', ' <|Bot|>:']  # used in wrap history
        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            if history is not None:
                add_special_tokens = False
            else:
                add_special_tokens = i == 0
            seg_embeds = self.encode_text(
                seg, add_special_tokens=add_special_tokens)
            prompt_seg_embeds.append(seg_embeds)
        if img_embeds is None:
            img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                               text_embeds.size(-1))
        prompt_seg_embeds = [
            prompt_seg_embeds[0], img_embeds, text_embeds, prompt_seg_embeds[1]
        ]
        prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds
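

# A minimal usage sketch (not part of the model code). The checkpoint id and
# image path below are placeholders; a tokenizer must be attached manually,
# since `self.tokenizer` is left as None by __init__:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained(
#       'internlm/internlm-xcomposer-7b', trust_remote_code=True).cuda().eval()
#   model.tokenizer = AutoTokenizer.from_pretrained(
#       'internlm/internlm-xcomposer-7b', trust_remote_code=True)
#
#   response = model.generate('Describe this image.', image='example.jpg')
#   response, history = model.chat('What else do you see?', history=None)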