import io
import logging
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import MSELoss
from .llm.llama_xformer import LlamaForCausalLM

from petrel_client.client import Client
from torch.cuda.amp import autocast

from .vision_encoder import pretrain_internvideo2_giant_patch14_224_clean, build_vit, interpolate_pos_embed_internvideo2_new
from .bridge import build_qformer, build_causal_qformer

logger = logging.getLogger(__name__)

from transformers import LlamaTokenizer, AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoProcessor
from transformers import AutoConfig, PreTrainedModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def freeze_module(module):
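    """Freeze all parameters of ``module``, put it in eval mode, and patch its
    ``train`` method so later ``.train()`` calls cannot switch it back."""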
    for param in module.parameters():
        param.requires_grad = False
    module = module.eval()
    module.train = disabled_train
    return module


class LLMConfig(AutoConfig):
    model_type = ""


class BaseMLLM(PreTrainedModel):
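    """Base multi-modal LLM that wires a vision encoder to a language model
    through a QFormer-style bridge. All sub-modules are built in __init__ via
    the build_* methods; the forward logic is left to subclasses."""
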
    config_class = LLMConfig
    def __init__(self, config):
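        """Build the vision encoder, LLM, bridge, and loss, load pretrained
        weights, and log which parameters remain trainable."""
        # Stash the nested model_config on the instance and strip it from the
        # HF config before calling PreTrainedModel.__init__.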
        self.model_config = config.model_config
        config.model_config = None
        super().__init__(config)
        self.build_vision_encoder()
        self.build_llm()
        self.build_bridge()
        self.build_loss()
        self.load_pretrained_weights()
        # NOTE: keep this logging after the LLM has been frozen so that only the
        # parameters that remain trainable are reported.
        for n, p in self.named_parameters():
            if p.requires_grad:
                logger.info(f'{n} requires_grad')
        
    
    def build_vision_encoder(self):
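        """Build the vision encoder (InternVideo2-1B or a plain ViT), the
        optional post-encoder LayerNorm, and optionally freeze both."""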
        # Build the InternVideo2-1B encoder here; simplified since it takes no extra args.
        # Note that the InternVideo2 pretrained weights are not loaded at this point
        # (see load_pretrained_weights, called later in __init__).
        if 'internvideo2' in self.model_config.vision_encoder.name.lower():
            encoder_name = self.model_config.vision_encoder.name
            logger.info(f"Build vision_encoder: {encoder_name}")
            if encoder_name == 'internvideo2-1B':
                self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config)
            else:
                raise ValueError(f"Not implemented: {encoder_name}")
        elif 'vit' in self.model_config.vision_encoder.name.lower():
            self.vision_encoder = build_vit(self.model_config)
        else:
            raise NotImplementedError(self.model_config.vision_encoder.name)

        if self.model_config.vision_encoder.vit_add_ln:
            self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12)
        else:
            self.vision_layernorm = nn.Identity()

        self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False)

        if self.freeze_vision_encoder:
            logger.info("freeze vision encoder")
            freeze_module(self.vision_encoder)
            freeze_module(self.vision_layernorm)


    def build_bridge(self):
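        """Build the vision-to-LLM bridge: linear up/down projections plus,
        depending on the config, a (causal) QFormer with learnable query tokens."""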
        # QFormer to LM: project the 768-d QFormer outputs up to the LM hidden size
        # (overall bridge: ViT 1792 -> QFormer 768 -> LM hidden size, e.g. 6656).
        self.project_up = nn.Linear(768, self.lm.config.hidden_size)  # TODO: check whether a bias term is needed
        # LM to QFormer: project the LM hidden size back down to 768
        self.project_down = nn.Linear(self.lm.config.hidden_size, 768)
        
        if 'qformer' in self.model_config.bridge.name.lower():
            from transformers import BertTokenizer
            self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left", local_files_only=True)
            self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
            self.qformer_tokenizer.padding_side = "left"
            if self.model_config.bridge.name == 'qformer':
                self.qformer, self.query_tokens = build_qformer(
                        self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
                        qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
                        qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
                        qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
                )
            elif self.model_config.bridge.name == 'causal_qformer':
                self.qformer, self.query_tokens = build_causal_qformer(
                        self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
                        qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
                        qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob
                )
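            # Resize the embedding table for the newly added [DEC] token and drop
            # the BERT prediction head, which is not used here.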
            self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
            self.qformer.cls = None
            self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
            if self.model_config.bridge.extra_num_query_token > 0:
                logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
                self.extra_query_tokens = nn.Parameter(
                    torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
                )
            
            self.freeze_bridge = self.model_config.get("freeze_bridge", False)
            if self.freeze_bridge:
                logger.info("freeze bridge")
                freeze_module(self.qformer)
                self.query_tokens.requires_grad = False

    def build_llm(self):
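        """Load the language model specified in the config, optionally freeze it,
        and optionally wrap it with LoRA adapters."""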
        self.lm_name = self.model_config.llm.name
        if self.model_config.llm.name == "vicuna1.5_7b":
            self.lm = LlamaForCausalLM.from_pretrained(self.model_config.llm.pretrained_llm_path)
            self.lm.gradient_checkpointing = self.model_config.llm.get("use_llama_gradient_checkpointing", True)
        elif self.model_config.llm.name == 'mistral_7b':
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                # attn_implementation="flash_attention_2",
            )
        elif self.model_config.llm.name == 'internlm_20b':
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.lm.gradient_checkpointing = True
            self.lm._set_gradient_checkpointing()
        elif self.model_config.llm.name == 'internlm2_5_7b':
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                local_files_only=True,
            )
        else:
            raise NotImplementedError(self.model_config.llm.name)

        self.freeze_llm = self.model_config.get("freeze_llm", True)
        logger.info(f'freeze_llm: {self.freeze_llm}')
        if self.freeze_llm:
            logger.info("freeze llm")
            freeze_module(self.lm)
        
        if self.model_config.llm.use_lora:
            self.use_lora = True
            from peft import get_peft_model, LoraConfig, TaskType
            logger.info("Use lora")
            if self.model_config.llm.name == 'internlm_20b':
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False, 
                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output']
                )
            else:
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False, 
                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                    "gate_proj", "up_proj", "down_proj", "lm_head"]
                )
                
            self.lm = get_peft_model(self.lm, peft_config)
            self.lm.enable_input_require_grads()
            self.lm.print_trainable_parameters()
        else:
            self.use_lora = False


    def build_loss(self):
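        """Set up the optional MSE loss used for vision feature regression."""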
        self.use_vision_regression_loss = self.model_config.loss.get("use_vision_regression_loss", False)
        if self.use_vision_regression_loss:
            self.image_loss_fct = MSELoss()
        
    @property
    def dtype(self):
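        """dtype of the underlying language model."""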
        return self.lm.dtype


    @property
    def device(self):
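        """Device of the underlying language model."""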
        return self.lm.device