# Ensure flash-attn is importable; if it is missing, install it at startup
# (the models below are loaded with attn_implementation="flash_attention_2").
try:
    import flash_attn
except ImportError:
    import importlib
    import os
    import subprocess
    import sys

    print("Installing flash-attn...")
    subprocess.run(
        # Use the running interpreter's pip so the package lands in the same
        # environment this process imports from; a list avoids the shell.
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Extend, don't replace, the environment: passing only this variable
        # would drop PATH, HOME, CUDA-related variables, etc.
        # FLASH_ATTENTION_SKIP_CUDA_BUILD asks flash-attn's setup to skip
        # compiling its CUDA extension.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        # Fail loudly with the pip error instead of a confusing ImportError.
        check=True,
    )
    # A package installed after interpreter start-up may not be visible to the
    # import system's cached finders until the caches are invalidated.
    importlib.invalidate_caches()
    import flash_attn

    print("flash-attn installed.")

import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread

import gradio as gr

import spaces


# Size label (shown in the model selector) -> Hugging Face Hub repo id.
# Every repo follows "llm-jp/llm-jp-3-<size>-instruct3" with the size
# lowercased; insertion order is smallest to largest.
MODEL_NAME_MAP = {
    size: f"llm-jp/llm-jp-3-{size.lower()}-instruct3"
    for size in ("150M", "440M", "980M", "1.8B", "3.7B", "13B")
}

# 4-bit NF4 quantization with double quantization and bfloat16 compute,
# shared by every model loaded below.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def _load_model(repo_id):
    """Load one causal LM from ``repo_id`` with the shared 4-bit config and FlashAttention-2."""
    return AutoModelForCausalLM.from_pretrained(
        repo_id,
        quantization_config=quantization_config,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )


# Eagerly load every model first, then every tokenizer, keyed like MODEL_NAME_MAP.
MODELS = {size: _load_model(repo) for size, repo in MODEL_NAME_MAP.items()}
TOKENIZERS = {
    size: AutoTokenizer.from_pretrained(repo)
    for size, repo in MODEL_NAME_MAP.items()
}

print("Compiling model...")
# Replace every entry in MODELS with its torch.compile wrapper. Reassigning
# existing keys while iterating .items() is safe: the dict's size never changes.
# NOTE(review): the loop variables `key` and `model` leak into module scope
# (`model` ends up bound to the last model compiled). Other code in this file
# may read a bare global `model`; check before changing the shape of this loop.
# NOTE(review): torch.compile returns a wrapper module and compiles lazily on
# first call. Presumably attribute access such as `.generate` / `.device` is
# forwarded to the original module, in which case `.generate()` may not go
# through the compiled forward at all — confirm whether this has any effect.
for key, model in MODELS.items():
    MODELS[key] = torch.compile(model)
print("Model compiled.")


@spaces.GPU(duration=45)
def generate(
    model_name: str,
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    """Stream a sampled chat reply from the model selected by ``model_name``.

    Args:
        model_name: Key into ``MODELS`` / ``TOKENIZERS`` (e.g. ``"1.8B"``).
        message: The new user message.
        history: Previous ``(user, assistant)`` turns; empty entries are skipped
            when building the prompt.
        system_message: System prompt placed first in the chat template.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.

    Yields:
        ``("", new_history)`` tuples. The empty string clears the input box.
        ``new_history`` is ``history`` plus the current turn holding the reply
        generated so far. A blank ``message`` yields ``("", history)`` once.
    """
    if not message or message.strip() == "":
        # This is a generator: a `return value` would never reach the caller,
        # so yield the unchanged state to still give the UI an update.
        yield "", history
        return

    # Resolve the model for *this* request so inputs go to its device.
    model = MODELS[model_name]
    tokenizer = TOKENIZERS[model_name]

    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    # skip_prompt drops the echoed prompt. With a timeout, iteration raises
    # instead of blocking forever if the generation thread dies.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        # One unpadded sequence, so every position is real input. Passing the
        # mask explicitly stops transformers from having to infer it.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    # model.generate blocks until finished, so run it in a background thread
    # and consume decoded text from the streamer here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_message = ""
    # Show the user's turn right away, even if the model emits no text.
    yield "", history + [(message, partial_message)]
    for new_text in streamer:
        partial_message += new_text
        yield "", history + [(message, partial_message)]

    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()


def respond(
    model_name: str,
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    """Stream the reply to a newly submitted user message.

    Thin adapter over ``generate``: every argument is forwarded unchanged,
    and each streamed update is re-emitted as a plain tuple
    ``(input_box_value, history)`` for the Gradio outputs.
    """
    updates = generate(
        model_name,
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    )
    for update in updates:
        yield tuple(update)


def retry(
    model_name: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    """Regenerate the reply to the most recent user message.

    Drops the last ``(user, assistant)`` turn from ``history`` and streams a
    fresh reply to that turn's user message via ``generate``.

    Yields:
        ``(input_box_value, history)`` tuples from ``generate``. If the history
        is empty there is nothing to regenerate, so ``("", history)`` is
        yielded once instead of raising ``IndexError``.
    """
    if not history:
        yield "", history
        return

    # Remove the last turn; its user message is sent to the model again.
    user_message = history[-1][0]
    previous_turns = history[:-1]

    for stream in generate(
        model_name,
        user_message,
        previous_turns,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield tuple(stream)


def demo():
    """Build the Gradio chat interface for the llm-jp-3 instruct3 models and launch it.

    Layout, top to bottom: title, model selector, chat transcript,
    regenerate/clear buttons, input row, collapsible generation settings and
    example prompts. Submitting streams a reply via ``respond``; the
    regenerate button streams via ``retry``.
    """
    model_choices = list(MODELS)

    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# (unofficial) llm-jp/llm-jp-3 instruct3 モデルデモ
コレクション: https://huggingface.co/collections/llm-jp/llm-jp-3-fine-tuned-models-672c621db852a01eae939731
"""
        )

        model_name_radio = gr.Radio(
            label="モデル", choices=model_choices, value=model_choices[0]
        )

        chat_history = gr.Chatbot(value=[])

        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1)
            gr.ClearButton(components=[chat_history], value="🗑️ 削除", scale=1)

        with gr.Row():
            input_text = gr.Textbox(
                value="",
                placeholder="質問を入力してください...",
                show_label=False,
                scale=8,
            )
            start_btn = gr.Button(value="送信", variant="primary", scale=2)

        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
                label="システムプロンプト",
                value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。",
            )
            max_new_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"
            )
            top_k_slider = gr.Slider(
                minimum=10, maximum=500, value=100, step=10, label="Top-k"
            )

        gr.Examples(
            examples=[
                ["情けは人の為ならずとはどういう意味ですか?"],
                ["まどマギで一番可愛いのは誰?"],
            ],
            inputs=[input_text],
            cache_examples=False,
        )

        # Trailing inputs shared by both handlers, in their parameter order.
        generation_settings = [
            system_prompt_text,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
        ]

        gr.on(
            triggers=[start_btn.click, input_text.submit],
            fn=respond,
            inputs=[model_name_radio, input_text, chat_history, *generation_settings],
            outputs=[input_text, chat_history],
        )
        retry_btn.click(
            fn=retry,
            inputs=[model_name_radio, chat_history, *generation_settings],
            outputs=[input_text, chat_history],
        )

    ui.launch()


if __name__ == "__main__":
    demo()