# coding=utf-8
# author: xusong
# time: 2022/8/23 16:06

import gradio as gr
from vocab import tokenizer_factory
from playground_examples import example_types, example_fn
from playground_util import (tokenize,
                             tokenize_pair, basic_count,
                             get_overlap_token_size, on_load)

get_window_url_params = """
    function(url_params) {
        const params = new URLSearchParams(window.location.search);
        url_params = JSON.stringify(Object.fromEntries(params));
        return url_params;
        }
    """

all_tokenizer_name = [(config.name_display, config.name_or_path) for config in tokenizer_factory.all_tokenizer_configs]

# UI layout and event wiring for the tokenizer-comparison playground.
# Layout (top to bottom): input text + examples dropdown, two side-by-side
# tokenizer panels with statistics, highlighted token views, and token tables.
with gr.Blocks() as demo:
    # links: https://www.coderstool.com/utf8-encoding-decoding
    # Purpose: take input text and tokenize it.
    # Tokenizers: several common tokenizers are available.
    # Background: convenient tokenization, inspecting token granularity, side-by-side comparison.

    # Header row: section title plus an examples dropdown. type="index" means
    # example_fn receives the selected item's position, not its label.
    with gr.Row():
        gr.Markdown("## Input Text")
        dropdown_examples = gr.Dropdown(
            example_types,
            value="Examples",  # placeholder label, not a real example (hence allow_custom_value)
            type="index",
            allow_custom_value=True,
            show_label=False,
            container=False,
            scale=0,
            elem_classes="example-style"
        )
    user_input = gr.Textbox(
        # value=default_user_input,
        label="Input Text",
        lines=5,
        show_label=False,
    )
    gr.Markdown("## Tokenization")

    # compress rate setting TODO: move this module further down the page
    # with gr.Accordion("Compress Rate Setting", open=True):
    #     gr.Markdown(
    #         "Please select corpus and unit of compress rate, get more details at [github](https://github.com/xu-song/tokenizer-arena/). ")
    #     with gr.Row():
    #         compress_rate_corpus = gr.CheckboxGroup(
    #             common_corpuses,  # , "code"
    #             value=["cc100-en", "cc100-zh-Hans"],
    #             label="corpus",
    #             # info=""
    #         )
    #         compress_rate_unit = gr.Radio(
    #             common_units,
    #             value="b_tokens/g_bytes",
    #             label="unit",
    #         )
    # TODO: Token Setting
    # with gr.Accordion("Token Filter Setting", open=False):
    #     gr.Markdown(
    #         "Get total number of tokens which contain the following character)")
    #     gr.Radio(
    #         ["zh-Hans", "", "number", "space"],
    #         value="zh",
    #     )

    # Left panel: tokenizer 1 selector plus its read-only statistics fields.
    with gr.Row():
        with gr.Column(scale=6):
            with gr.Group():
                tokenizer_name_1 = gr.Dropdown(
                    all_tokenizer_name,
                    label="Tokenizer 1",
                    # value=default_tokenizer_name_1,
                )
                with gr.Group():
                    with gr.Row():
                        organization_1 = gr.TextArea(
                            label="Organization",
                            lines=1,
                            elem_classes="statistics",
                        )
                        stats_vocab_size_1 = gr.TextArea(
                            label="Vocab Size",
                            lines=1,
                            elem_classes="statistics"
                        )
                        # stats_zh_token_size_1 = gr.TextArea(
                        #     label="ZH char/word",
                        #     lines=1,
                        #     elem_classes="statistics",
                        # )
                        # stats_compress_rate_1 = gr.TextArea(
                        #     label="Compress Rate",
                        #     lines=1,
                        #     elem_classes="statistics",
                        # )
                        stats_overlap_token_size_1 = gr.TextArea(
                            # value=default_stats_overlap_token_size,
                            label="Overlap Tokens",
                            lines=1,
                            elem_classes="statistics"
                        )
                        # stats_3 = gr.TextArea(
                        #     label="Compress Rate",
                        #     lines=1,
                        #     elem_classes="statistics"
                        # )
        # "VS" divider icon between the two panels.
        # icon source: https://www.onlinewebfonts.com/icon/418591
        gr.Image("images/VS.svg", scale=1, show_label=False,
                 show_download_button=False, container=False,
                 show_share_button=False)
        # Right panel: tokenizer 2, mirroring the left panel's statistics.
        with gr.Column(scale=6):
            with gr.Group():
                tokenizer_name_2 = gr.Dropdown(
                    all_tokenizer_name,
                    label="Tokenizer 2",
                    # value=default_tokenizer_name_2
                )
                with gr.Group():
                    with gr.Row():
                        organization_2 = gr.TextArea(
                            label="Organization",
                            lines=1,
                            elem_classes="statistics",
                        )
                        stats_vocab_size_2 = gr.TextArea(
                            label="Vocab Size",
                            lines=1,
                            elem_classes="statistics"
                        )
                        # stats_zh_token_size_2 = gr.TextArea(
                        #     label="ZH char/word",  # Chinese chars/words
                        #     lines=1,
                        #     elem_classes="statistics",
                        # )
                        # stats_compress_rate_2 = gr.TextArea(
                        #     label="Compress Rate",
                        #     lines=1,
                        #     elem_classes="statistics"
                        # )
                        # Hidden placeholder; no event handler below writes to it yet.
                        stats_filtered_token_2 = gr.TextArea(
                            label="filtered tokens",
                            lines=1,
                            elem_classes="statistics",
                            visible=False
                        )
                        stats_overlap_token_size_2 = gr.TextArea(
                            label="Overlap Tokens",
                            lines=1,
                            elem_classes="statistics"
                        )

    # TODO: chart / table for compression rate
    # Token visualizations, one column per tokenizer. The label is set
    # dynamically by the tokenize handlers.
    # NOTE(review): gradio's documented component class is gr.HighlightedText —
    # confirm the lowercase "Highlightedtext" alias exists in the pinned gradio version.
    with gr.Row():
        # dynamic change label
        with gr.Column():
            output_text_1 = gr.Highlightedtext(
                show_legend=False,
                show_inline_category=False,
                elem_classes="space-show"
            )
        with gr.Column():
            output_text_2 = gr.Highlightedtext(
                show_legend=False,
                show_inline_category=False,
                elem_classes="space-show"
            )

    # Per-token detail tables (token id / text / bytes), one per tokenizer.
    with gr.Row():
        output_table_1 = gr.Dataframe()
        output_table_2 = gr.Dataframe()

    # setting
    # compress_rate_unit.change(compress_rate_unit_change, [compress_rate_unit],
    #                             [stats_compress_rate_1, stats_compress_rate_2])

    # Changing tokenizer 1 re-tokenizes the current input, refreshes its basic
    # stats, and recomputes the vocabulary overlap (which updates both panels).
    # NOTE(review): these three handlers omit show_api=False while the
    # tokenizer 2 handlers set it — presumably deliberate, to expose exactly one
    # copy of each function in the API docs; confirm.
    tokenizer_name_1.change(tokenize, [user_input, tokenizer_name_1],
                            [output_text_1, output_table_1])
    tokenizer_name_1.change(basic_count, [tokenizer_name_1], [stats_vocab_size_1, organization_1])
    tokenizer_name_1.change(get_overlap_token_size, [tokenizer_name_1, tokenizer_name_2],
                            [stats_overlap_token_size_1, stats_overlap_token_size_2])
    # tokenizer_type_1.change(get_compress_rate, [tokenizer_type_1, compress_rate_corpus, compress_rate_unit],
    #                         [stats_compress_rate_1])

    # Re-tokenize with both tokenizers whenever the input text changes.
    # TODO: every=3
    user_input.change(tokenize_pair,
                      [user_input, tokenizer_name_1, tokenizer_name_2],
                      [output_text_1, output_table_1, output_text_2, output_table_2], show_api=False)  # , pass_request=1

    # Mirror of the tokenizer 1 handlers for tokenizer 2 (hidden from the API docs).
    tokenizer_name_2.change(tokenize, [user_input, tokenizer_name_2],
                            [output_text_2, output_table_2], show_api=False)
    tokenizer_name_2.change(basic_count, [tokenizer_name_2], [stats_vocab_size_2, organization_2], show_api=False)
    tokenizer_name_2.change(get_overlap_token_size, [tokenizer_name_1, tokenizer_name_2],
                            [stats_overlap_token_size_1, stats_overlap_token_size_2], show_api=False)
    # tokenizer_type_2.change(get_compress_rate,
    #                         [tokenizer_type_2, compress_rate_corpus, compress_rate_unit],
    #                         [stats_compress_rate_2])
    #
    # compress_rate_unit.change(get_compress_rate,
    #                           [tokenizer_type_1, compress_rate_corpus, compress_rate_unit],
    #                           [stats_compress_rate_1])
    # compress_rate_unit.change(get_compress_rate,
    #                           [tokenizer_type_2, compress_rate_corpus, compress_rate_unit],
    #                           [stats_compress_rate_2])
    # compress_rate_corpus.change(get_compress_rate,
    #                             [tokenizer_type_1, compress_rate_corpus, compress_rate_unit],
    #                             [stats_compress_rate_1])
    # compress_rate_corpus.change(get_compress_rate,
    #                             [tokenizer_type_2, compress_rate_corpus, compress_rate_unit],
    #                             [stats_compress_rate_2])

    # Selecting an example (by index, see type="index" above) fills the input
    # text and both tokenizer dropdowns.
    dropdown_examples.change(
        example_fn,
        dropdown_examples,
        [user_input, tokenizer_name_1, tokenizer_name_2],
        show_api=False
    )

    # On page load: run the JS snippet to grab URL query params, then let
    # on_load() initialize the input text and tokenizer selections from them.
    demo.load(
        fn=on_load,
        inputs=[user_input],  # only an empty placeholder needs to be passed here; js supplies the real value
        outputs=[user_input, tokenizer_name_1, tokenizer_name_2],
        js=get_window_url_params,
        show_api=False
    )

if __name__ == "__main__":
    # demo.queue(max_size=20).launch()
    demo.launch()
    # demo.launch(share=True)