# Optional Hugging Face `spaces` integration: when the package can be imported,
# `maybe_spaces_gpu` wraps a function with `spaces.GPU`; otherwise it is a
# pass-through decorator.
try:
    import spaces
except ModuleNotFoundError:
    print(f'Cannot import hf `spaces` with `import spaces`.')

    def maybe_spaces_gpu(fn):
        """Return ``fn`` unchanged (the `spaces` package is not available)."""
        return fn
else:

    def maybe_spaces_gpu(fn):
        """Return ``fn`` wrapped by ``spaces.GPU``."""
        return spaces.GPU(fn)

import os
import numpy as np
import argparse
import torch
import sys
import gradio as gr
from typing import Any, Iterator
from typing import Iterator, List, Optional, Tuple
import filelock
import glob
import json
import time
from gradio.routes import Request
from gradio.utils import SyncToAsyncIterator, async_iteration
from gradio.helpers import special_args
import anyio
from typing import AsyncGenerator, Callable, Literal, Union, cast

from gradio_client.documentation import document, set_documentation_group

from typing import List, Optional, Union, Dict, Tuple
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
import types

from gradio.components import Button
from gradio.events import Dependency, EventListenerMethod

from .base_engine import BaseEngine

# ! Remember to use static cache

from transformers import (
    GenerationConfig,
    GenerationMixin,
    LogitsProcessorList,
    StoppingCriteriaList,
    DisjunctiveConstraint,
    BeamSearchScorer,
    PhrasalConstraint,
    ConstrainedBeamSearchScorer,
    PreTrainedModel,
)
import numpy as np
import random
import warnings
import inspect
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
import torch
from typing import Callable, List, Optional, Union
from torch import nn
import torch.distributed as dist
import copy

from ..configs import (
    MODEL_PATH,
    DTYPE,
    DEVICE,
    STREAM_CHECK_MULTIPLE,
    STREAM_YIELD_MULTIPLE,
)




def setup_seed(seed):
    """Seed torch, NumPy and Python's `random` with the same value.

    Args:
        seed: Integer seed. The sentinel ``-1`` means "do not seed": the
            function returns immediately and leaves every RNG untouched.

    Side effects:
        Besides seeding, sets ``torch.backends.cudnn.deterministic = True``.
        CUDA generators are seeded only when ``torch.cuda.is_available()``.
    """
    if seed != -1:
        torch.manual_seed(seed)
        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True


class NewGenerationMixin(GenerationMixin):
    """
    `GenerationMixin` subclass that adds `sample_stream`, a generator form of multinomial sampling which yields
    the tokens sampled at every decoding step instead of returning the finished sequences.
    """

    # ! Copy from transformers.generation.utils -> GenerationMixin
    # Change sample function to sample_stream
    @torch.no_grad()
    def sample_stream(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ):
        r"""
        Generate token ids with **multinomial sampling**, yielding the sampled tokens of each step.

        The body follows the sampling loop of `transformers.generation.utils.GenerationMixin` (see the comment
        above the method); the differences visible here are the `yield next_tokens.cpu()` inside the loop and
        the disabled final `return` (kept as commented-out code at the end of the method).

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Applied to the last-position logits at each step, before `logits_warper`. Defaults to an empty
                list.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Called as `stopping_criteria(input_ids, scores)` after each step; a truthy result ends
                generation. Defaults to an empty list.
            logits_warper (`LogitsProcessorList`, *optional*):
                Applied to the processed scores right before the softmax / multinomial draw. Defaults to an
                empty list.
            max_length (`int`, *optional*):
                **DEPRECATED**. When given, a `UserWarning` is issued and `stopping_criteria` is passed through
                `validate_stopping_criteria(stopping_criteria, max_length)`.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token; falls back to `self.generation_config.pad_token_id`. Written in
                place of the sampled token for sequences that already finished.
            eos_token_id (`Union[int, List[int]]`, *optional*):
                The id(s) of the *end-of-sequence* token; falls back to `self.generation_config.eos_token_id`.
            output_attentions, output_hidden_states (`bool`, *optional*):
                Forwarded to the model's forward call; each falls back to the same-named `generation_config`
                field when `None`.
            output_scores, output_logits (`bool`, *optional*):
                Each falls back to the same-named `generation_config` field when `None`.
            return_dict_in_generate (`bool`, *optional*):
                Falls back to `self.generation_config.return_dict_in_generate`. When truthy, scores / logits /
                attentions / hidden states are accumulated per step according to the `output_*` flags, but
                because the final `return` is disabled only `scores` is used afterwards (it is passed to
                `stopping_criteria`).
            synced_gpus (`bool`, *optional*, defaults to `False`):
                When `True`, every iteration first all-reduces a "still running" flag across peers and the
                loop only exits once all peers have finished.
            streamer (`BaseStreamer`, *optional*):
                If given, `streamer.put(next_tokens.cpu())` is called after each step and `streamer.end()`
                after the loop.
            model_kwargs:
                Forwarded to `self.prepare_inputs_for_generation`. For an encoder-decoder model with
                `return_dict_in_generate` enabled, `model_kwargs["encoder_outputs"]` is read.

        Yields:
            `torch.LongTensor`: the token ids sampled at the current step, moved to CPU, one entry per sequence
            in the batch (finished sequences carry `pad_token_id`).

        Raises:
            ValueError: if `eos_token_id` is set but `pad_token_id` is `None`.
        """
        # init values
        from transformers.generation.utils import (
            validate_stopping_criteria, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
        )
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        # normalise a single eos id to a list so it can be turned into a tensor below
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        # (only referenced by the disabled return at the end of this method)
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished (1 = still generating, 0 = finished)
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            # logits of the last position only
            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample one token per sequence from the warped distribution
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # streaming hook: hand this step's tokens to the caller before continuing
            yield next_tokens.cpu()

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            
            # only `cache_position` (when present) is passed on as `model_inputs`
            next_model_inputs = {}
            if "cache_position" in model_inputs:
                next_model_inputs['cache_position'] = model_inputs['cache_position']
            try:
                model_kwargs = self._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, 
                    # model_inputs=model_inputs
                    model_inputs=next_model_inputs,
                )
            except Exception as e:
                # ! some transformers version don't have model_inputs in generation
                # NOTE(review): this catches every exception, not only the unexpected-keyword case, so any
                # failure of the first call triggers a retry without `model_inputs`.
                model_kwargs = self._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, 
                    # model_inputs=model_inputs
                    # model_inputs=next_model_inputs,
                )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

        if streamer is not None:
            streamer.end()

        # The original (non-streaming) return of the full output is disabled below; this method only yields.
        # if return_dict_in_generate:
        #     if self.config.is_encoder_decoder:
        #         return GenerateEncoderDecoderOutput(
        #             sequences=input_ids,
        #             scores=scores,
        #             logits=raw_logits,
        #             encoder_attentions=encoder_attentions,
        #             encoder_hidden_states=encoder_hidden_states,
        #             decoder_attentions=decoder_attentions,
        #             cross_attentions=cross_attentions,
        #             decoder_hidden_states=decoder_hidden_states,
        #             past_key_values=model_kwargs.get("past_key_values"),
        #         )
        #     else:
        #         return GenerateDecoderOnlyOutput(
        #             sequences=input_ids,
        #             scores=scores,
        #             logits=raw_logits,
        #             attentions=decoder_attentions,
        #             hidden_states=decoder_hidden_states,
        #             past_key_values=model_kwargs.get("past_key_values"),
        #         )
        # else:
        #     return input_ids





def _split_env_list(raw, strip_items=False):
    """Split a ';'-separated environment value into its non-blank entries.

    Args:
        raw: The raw environment string (may be empty or whitespace only).
        strip_items: When true, surrounding whitespace is removed from each
            entry; otherwise entries are kept exactly as written.

    Returns:
        A list of entries. Blank entries (e.g. from a trailing or doubled
        ``;``) are dropped: an empty keyword would be a substring of every
        text and would therefore block every request.
    """
    raw = raw.strip()
    if not raw:
        return []
    items = raw.split(";")
    if strip_items:
        items = [item.strip() for item in items]
    return [item for item in items if item.strip()]


# Language codes to reject, from the ';'-separated BLOCK_LANGS environment variable.
BLOCK_LANGS = _split_env_list(str(os.environ.get("BLOCK_LANGS", "")), strip_items=True)
# When true, a conversation whose history already contains LANG_BLOCK_MESSAGE stays blocked.
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
# Lower-cased substrings to reject, from the ';'-separated KEYWORDS environment variable.
# Entries are not stripped individually, so inner/leading spaces of a keyword stay significant.
KEYWORDS = [x.lower() for x in _split_env_list(os.environ.get("KEYWORDS", ""))]

LANG_BLOCK_MESSAGE = """Unsupported language."""

KEYWORD_BLOCK_MESSAGE = "Invalid request."


def _detect_lang(text):
    """Return the language code `langdetect` reports for ``text``.

    Used to disable languages that may carry a safety risk. Detection
    failures never propagate: an error whose message contains
    "No features in text." maps to ``"en"``, any other error maps to ``"zh"``.
    """
    from langdetect import detect as detect_lang

    try:
        return detect_lang(text)
    except Exception as err:
        no_features = "No features in text." in str(err)
        return "en" if no_features else "zh"


def block_lang(
    message: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> bool:
    """Tell whether ``message`` must be rejected by the language filter.

    Args:
        message: The text whose language is detected.
        history: Optional previous turns; index 1 of each turn is read as the
            reply text.

    Returns:
        ``False`` when no language is configured in ``BLOCK_LANGS``.
        ``True`` when ``LANG_BLOCK_HISTORY`` is enabled and an earlier reply
        already contains ``LANG_BLOCK_MESSAGE``. Otherwise, whether the
        language detected for ``message`` is listed in ``BLOCK_LANGS``.
    """
    # relieve history base block
    if not BLOCK_LANGS:
        return False

    if LANG_BLOCK_HISTORY and history is not None:
        for turn in history:
            reply = turn[1]
            # Replies that are not text (e.g. None) cannot contain the block
            # message; skip them instead of failing on `.strip()`.
            if isinstance(reply, str) and LANG_BLOCK_MESSAGE in reply.strip():
                return True

    return _detect_lang(message) in BLOCK_LANGS
        
def safety_check(text, history=None) -> Optional[str]:
    """
    Extra safety layer on top of model tuning: reject text by keyword or by language.

    Returns ``KEYWORD_BLOCK_MESSAGE`` when any configured keyword occurs in the lower-cased ``text``,
    ``LANG_BLOCK_MESSAGE`` when ``block_lang(text, history)`` flags it, and ``None`` when the text passes.
    Each check only runs when its configuration list (``KEYWORDS`` / ``BLOCK_LANGS``) is non-empty.
    """
    if KEYWORDS:
        lowered = text.lower()
        if any(keyword in lowered for keyword in KEYWORDS):
            return KEYWORD_BLOCK_MESSAGE

    if BLOCK_LANGS and block_lang(text, history):
        return LANG_BLOCK_MESSAGE

    return None


def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
    """Safety-check a whole conversation given as one formatted string.

    The keyword check runs on the full lower-cased ``text``. For the language
    check, ``text`` is split on the regex alternatives in ``delimiter``
    (defaulting to the chat-template turn markers below) and every non-blank
    piece is passed to ``block_lang``.

    Returns:
        ``KEYWORD_BLOCK_MESSAGE``, ``LANG_BLOCK_MESSAGE`` or ``None`` when
        nothing is flagged.
    """
    if KEYWORDS:
        lowered = text.lower()
        if any(keyword in lowered for keyword in KEYWORDS):
            return KEYWORD_BLOCK_MESSAGE

    if not BLOCK_LANGS:
        return None

    import re
    patterns = delimiter or (r"</s><\|im_start\|>user\n", r"</s><\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
    for piece in re.split("|".join(patterns), text):
        if piece.strip() != '' and block_lang(piece):
            return LANG_BLOCK_MESSAGE
    return None


def is_check_safety():
    """Return True when at least one safety filter (keywords or blocked languages) is configured."""
    return bool(KEYWORDS) or bool(BLOCK_LANGS)


def safety_check_conversation(conversation) -> Optional[str]:
    """
    Extra safety layer on top of model tuning: check every message of a conversation.

    ``conversation`` is a sequence of mappings with a ``'content'`` entry. Messages are examined in order;
    the first one that contains a configured keyword yields ``KEYWORD_BLOCK_MESSAGE`` and the first one
    flagged by ``block_lang`` yields ``LANG_BLOCK_MESSAGE``. ``None`` means every message passed.
    """
    # Collect all contents up front, as before, so a malformed entry fails before any check runs.
    contents = [turn['content'] for turn in conversation]
    for content in contents:
        if KEYWORDS:
            lowered = content.lower()
            if any(keyword in lowered for keyword in KEYWORDS):
                return KEYWORD_BLOCK_MESSAGE

        if BLOCK_LANGS and block_lang(content):
            return LANG_BLOCK_MESSAGE
    return None





class TransformersEngine(BaseEngine):
    """Engine that serves a Hugging Face causal LM and streams its output token by token.

    Streaming is obtained by binding `NewGenerationMixin.sample_stream` onto
    the loaded model (see `load_model` / `generate_yield_string`).
    """

    @property
    def max_position_embeddings(self) -> int:
        """`max_position_embeddings` as reported by the loaded model's config."""
        return self._model.config.max_position_embeddings

    @property
    def tokenizer(self):
        """The tokenizer created by `load_model`."""
        return self._tokenizer

    def load_model(self):
        """Load tokenizer and model from `MODEL_PATH` and install the streaming sampler.

        The dtype is bfloat16 when `DTYPE == 'bfloat16'`, float16 otherwise;
        `DEVICE` is used as the `device_map`. The model's original `sample`
        method is kept as `sample_old`, and `sample_stream` is bound as
        `_sample`.

        Raises:
            AssertionError: if the tokenizer has no (or an empty) chat template.
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM

        self.model_path = model_path = MODEL_PATH
        self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
        self.device_map = DEVICE
        print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype}')

        self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        assert self._tokenizer.chat_template is not None and self._tokenizer.chat_template != "", f"{self._tokenizer.chat_template=} not found!"
        self._model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True).eval()
        self._model.sample_old = self._model.sample
        self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
        print(self._model)
        print(f"{self.max_position_embeddings=}")

    def maybe_raise_safety(self, message, gen_index=-1):
        """Run the conversation-string safety check and raise `gr.Error` if it flags `message`.

        Args:
            message: Text to check (a prompt or a partial/final response).
            gen_index: Streaming step index. A negative value always checks;
                otherwise the check only runs when `STREAM_CHECK_MULTIPLE > 0`
                and `gen_index` is a multiple of it.

        Raises:
            gr.Error: carrying the block message returned by the safety check.
        """
        if not is_check_safety():
            return
        if gen_index >= 0:
            # During streaming, only check every STREAM_CHECK_MULTIPLE-th step.
            due = STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0
            if not due:
                return
        message_safety = safety_check_conversation_string(message)
        if message_safety is not None:
            raise gr.Error(message_safety)

    # @maybe_spaces_gpu
    def generate_yield_string(self, prompt, temperature=0.7, max_tokens=1024, stop_strings: Optional[Tuple[str]] = None, **kwargs):
        """Stream the model's response to `prompt`.

        Args:
            prompt: Fully formatted prompt string; it is safety-checked first.
            temperature: Sampling temperature; `0` is replaced by `0.0001`
                because sampling is always enabled.
            max_tokens: Passed to `generate` as `max_new_tokens`.
            stop_strings: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Yields:
            `(response, num_tokens)` tuples. During streaming, `response` is
            the text decoded from all tokens generated so far and `num_tokens`
            is the prompt token count plus the number of steps. A final tuple
            is yielded with `num_tokens` recomputed by encoding
            `prompt + response`.

        Raises:
            gr.Error: if a safety check fails, or wrapping a `RuntimeError`
                raised during tokenization/generation.
        """
        # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
        # Bind the streaming sampler so that `generate` can be iterated below.
        # NOTE(review): presumably relies on the installed transformers version dispatching
        # `generate(do_sample=True)` to `sample` / `_sample` - confirm when upgrading.
        self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)

        self.maybe_raise_safety(prompt)

        if temperature == 0:
            temperature = 0.0001

        assistant_marker = "<|im_start|>assistant\n"

        try:
            with torch.no_grad():
                inputs = self.tokenizer(prompt, return_tensors='pt')
                begin_bos = inputs.input_ids[0][0] == self.tokenizer.bos_token_id
                print(f'begin_bos: {begin_bos}')

                num_tokens = inputs.input_ids.size(1)

                inputs = inputs.to(self._model.device)

                generator = self._model.generate(
                    **inputs,
                    do_sample=True,
                    temperature=temperature,
                    max_new_tokens=max_tokens,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

                out_tokens = []
                response = None
                for index, token in enumerate(generator):
                    out_tokens.extend(token.tolist())
                    # Re-decode everything generated so far on every step.
                    response = self.tokenizer.decode(out_tokens, skip_special_tokens=True)
                    # Keep only the text after the last assistant marker, if one was decoded.
                    if assistant_marker in response:
                        response = response.split(assistant_marker)[-1]
                    num_tokens += 1
                    self.maybe_raise_safety(response, gen_index=index)
                    yield response, num_tokens

                del generator
                if response is not None:
                    # `response` was already cut at the assistant marker inside the loop.
                    self.maybe_raise_safety(response)
                    full_text = prompt + response
                    num_tokens = len(self.tokenizer.encode(full_text))
                    yield response, num_tokens
        except RuntimeError as e:
            raise gr.Error(str(e)) from e