import gradio as gr
import json
import copy
import pandas as pd
from vocab import tokenizer_factory
from character_util import iter_vocab
from utils.log_util import logger
from utils.i18n_util import get_lang
from playground_examples import default_tokenizer_name_1, default_tokenizer_name_2, default_user_input
from functools import lru_cache


@lru_cache
def _tokenize(
        text: str,
        tokenizer_name: str,
        color_num: int = 5,
        add_special_token: bool = False
):
    """Tokenize ``text`` and build the highlight list and per-token table.

    The result is memoized with ``lru_cache``: every argument must be hashable,
    and the returned list / DataFrame are shared between calls that use the
    same arguments, so callers should not mutate them in place.

    Args:
        text: Input text to tokenize.
        tokenizer_name: Name handed to ``tokenizer_factory.get_tokenizer``.
        color_num: Number of color labels to cycle through; token ``idx`` is
            labelled ``str(idx % color_num)``.
        add_special_token: Forwarded to ``tokenizer.encode`` as
            ``add_special_tokens``.

    Returns:
        A tuple ``(pos_tokens, num_tokens, table_df)`` where ``pos_tokens`` is
        a list of ``(decoded_text, color_label)`` pairs, ``num_tokens`` is the
        number of token ids produced, and ``table_df`` is a DataFrame with one
        row per token (keys ``TokenID``, ``Token``, ``Text``, ``UTF8 Bytes``).
    """
    logger.info("param=" + json.dumps({"text": text, "tokenizer_type": tokenizer_name}, ensure_ascii=False))
    pos_tokens = []
    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
    encoding = tokenizer.encode(text, add_special_tokens=bool(add_special_token))

    table = []

    for idx, token_id in enumerate(encoding):
        # Special characters are uniformly decoded to the replacement character "\ufffd".
        decoded_text = tokenizer.decode([token_id])
        pos_tokens.append((decoded_text, str(idx % color_num)))

        # Open question from the original author: are byte tokens UTF-8 encoded?
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        if isinstance(token, bytes):
            try:
                token_str = token.decode("utf-8")
            except UnicodeDecodeError:
                # A single token may hold only part of a multi-byte sequence;
                # fall back to a lossy decode and record it
                # (gpt_35_turbo frequently produces such tokens).
                token_str = token.decode("utf-8", errors="ignore")
                logger.error(f"{idx}: decode_error: " + json.dumps(
                    {"tokenizer_type": tokenizer_name, "token": str(token), "token_str": token_str},
                    ensure_ascii=False))

            token_bytes = token
        elif isinstance(token, str):
            token_str = token
            token_bytes = bytes(token_str, "utf-8")
        else:
            # Unexpected token type: log it and pass the raw value through so
            # the row is still emitted.
            logger.error(f"{idx}: wrong type for token {token_id} {type(token)} " + json.dumps(
                {"text": text, "tokenizer_type": tokenizer_name}, ensure_ascii=False))
            token_str = token
            token_bytes = token

        # TODO: gpt3.5_turbo issue: only the id and text are correct, while the
        # token and utf8 columns are wrong, which suggests
        # convert_ids_to_tokens is misbehaving.
        table.append(
            {"TokenID": token_id,
             # String form of the token. Open question: why do some tokenizers
             # (e.g. llama) show entries such as <0xE7>, and what do they mean?
             "Token": token_str,
             "Text": decoded_text,
             # Raw bytes are decoded back to text by the gradio frontend
             # (b'\xe4\xb8\xad' would still be displayed as the character),
             # so the repr string is stored instead.
             "UTF8 Bytes": str(token_bytes),
             }
        )

    table_df = pd.DataFrame(table)
    logger.info(f"tokenizer_type={tokenizer_name}, Tokens={table[:4]}")
    return pos_tokens, len(encoding), table_df


def tokenize(
        text: str,
        tokenizer_name: str,
        color_num: int = 5,
        add_special_token: bool = False
):
    """Gradio-facing wrapper around the cached ``_tokenize``.

    Caching lives in ``_tokenize`` rather than here because, per the original
    author's note, a ``gr.update`` object is overwritten once it has been
    passed to the frontend; a fresh update is therefore built on every call.

    Returns:
        A pair ``(update, table_df)``: a ``gr.update`` carrying the highlighted
        tokens with a ``"Tokens: <count>"`` label, and the per-token table.
    """
    cached_result = _tokenize(text, tokenizer_name, color_num, add_special_token)
    highlighted, token_count, token_table = cached_result
    highlight_update = gr.update(value=highlighted, label=f"Tokens: {token_count}")
    return highlight_update, token_table


def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Tokenize ``text`` with two tokenizers (handler for ``input_text.change``).

    Returns:
        A 4-tuple: the highlight update and table for the first tokenizer,
        followed by the highlight update and table for the second one.
    """
    first_result = tokenize(text, tokenizer_type_1)
    second_result = tokenize(text, tokenizer_type_2)
    # Each result is an (update, table) pair; flatten both into one tuple.
    return (*first_result, *second_result)


@lru_cache
def basic_count(tokenizer_name):
    """Return ``(vocab_size, organization)`` for the named tokenizer.

    Both values are read from the stats mapping produced by ``iter_vocab``;
    the organization is rendered to a string. Results are memoized per
    tokenizer name.
    """
    vocab_stats = iter_vocab(tokenizer_name)
    vocab_size = vocab_stats['vocab_size']
    organization = f'{vocab_stats["organization"]}'
    # NOTE: an earlier, disabled variant returned tokenizer.vocab_size together
    # with the single-/multi-character Chinese token counts taken from the stats.
    return vocab_size, organization


# def get_compress_rate(tokenizer_name, all_corpus, unit):
#     tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
#     compress_rate_stats = tokenize_corpus(tokenizer, all_corpus)
#     compress_rate = unit_convertor(compress_rate_stats, unit)
#     return compress_rate


@lru_cache
def get_overlap_token_size(tokenizer_name_1, tokenizer_name_2):
    """Count the tokens shared by the vocabularies of two tokenizers.

    Vocabulary keys may be ``str`` for one tokenizer and ``bytes`` for the
    other. When the key types differ, the ``str`` side is UTF-8 encoded so
    both sides are compared as ``bytes``. The key type is inferred from a
    single sampled token per vocabulary.

    Args:
        tokenizer_name_1: Name of the first tokenizer.
        tokenizer_name_2: Name of the second tokenizer.

    Returns:
        A 2-tuple holding the overlap size twice (presumably one value per
        output component -- verify against the caller).
    """
    tokenizer1 = tokenizer_factory.get_tokenizer(tokenizer_name_1)
    tokenizer2 = tokenizer_factory.get_tokenizer(tokenizer_name_2)

    vocab_set_1 = tokenizer1.get_vocab().keys()
    vocab_set_2 = tokenizer2.get_vocab().keys()

    # Sample one token from each vocabulary; the None default keeps an empty
    # vocabulary from raising StopIteration (its overlap is simply 0).
    token1 = next(iter(vocab_set_1), None)
    token2 = next(iter(vocab_set_2), None)
    both_non_empty = token1 is not None and token2 is not None
    if both_non_empty and type(token1) is not type(token2):  # bytes vs. str
        if isinstance(token1, str):
            vocab_set_1 = {token.encode("utf-8") for token in vocab_set_1}
        if isinstance(token2, str):
            vocab_set_2 = {token.encode("utf-8") for token in vocab_set_2}

    overlap_tokens = vocab_set_1 & vocab_set_2
    overlap_token_size = len(overlap_tokens)
    logger.info(
        f"{overlap_token_size} OverlapTokens of {tokenizer_name_1} {tokenizer_name_2}: {list(overlap_tokens)[:10]}")
    return overlap_token_size, overlap_token_size


def on_load(url_params, request: gr.Request):
    """Page-load handler: derive the initial text and tokenizer choices.

    Args:
        url_params: JSON string expected to encode an object with the optional
            keys ``"text"``, ``"tokenizer1"`` and ``"tokenizer2"``. Anything
            that cannot be parsed, or that does not parse to a JSON object, is
            treated as an empty mapping.
        request: The incoming gradio request; used for language detection and
            logging.

    Returns:
        A tuple ``(text, tokenizer_type_1, tokenizer_type_2)``. Missing keys
        fall back to the playground defaults. When ``request`` is falsy all
        three values are ``None``.
    """
    text = None
    tokenizer_type_1 = None
    tokenizer_type_2 = None
    try:
        url_params = json.loads(url_params)
    except (TypeError, ValueError, RecursionError):
        # TypeError: not a str/bytes; ValueError: malformed JSON
        # (JSONDecodeError is a subclass); RecursionError: pathologically nested input.
        url_params = {}
    if not isinstance(url_params, dict):
        # Valid JSON that is not an object (e.g. "null", "[]") has no .get().
        url_params = {}
    if request:
        lang, _ = get_lang(request)
        logger.info(str(request.headers))
        client_ip = request.client.host
        # local_ip = socket.gethostbyname(socket.gethostbyname(""))
        # headers = request.kwargs['headers']
        # if headers and 'x-forwarded-for' in headers:
        #     x_forwarded_for = headers['x-forwarded-for']
        #     client_ip = x_forwarded_for.split(' ')[0] if x_forwarded_for else ""
        # if "referer" in request.headers:   # not work for huggingface-space
        #     url_params = parse_qs(urlparse(request.headers["referer"]).query)
        #     url_params = {k: v[0] for k, v in url_params.items() if len(v) > 0}
        tokenizer_type_1 = url_params.get("tokenizer1", default_tokenizer_name_1)
        tokenizer_type_2 = url_params.get("tokenizer2", default_tokenizer_name_2)
        text = url_params.get("text", default_user_input)
        logger.info(f"client_ip: {client_ip}; lang: {lang} params: {url_params}")
    return text, tokenizer_type_1, tokenizer_type_2


# def compress_rate_unit_change(unit):
#     return gr.update(label=f"Compress Rate: {unit}"), gr.update(label=f"Compress Rate: {unit}"),


def test_coding():
    """Print a fixed three-byte sequence to show how ``bytes`` values render."""
    sample = bytes([0xE4, 0xB8, 0xAD])
    print(sample)  # b'\xe4\xb8\xad'


if __name__ == "__main__":
    # Manual smoke check: print the vocabulary overlap of two tokenizers.
    overlap_result = get_overlap_token_size("gpt-35-turbo", "gpt-4")
    print(overlap_result)
    # print(basic_count("internlm_chat_7b"))