try:
    import flash_attn
except ImportError:
    # flash-attn is not preinstalled in this environment: install it at startup.
    import os
    import subprocess
    import sys

    print("Installing flash-attn...")
    subprocess.run(
        # Use the running interpreter's pip so the package lands in the same
        # environment this process imports from; an argv list avoids the shell.
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "flash-attn",
            "--no-build-isolation",
        ],
        # Extend (do not replace) the environment: passing only the one variable
        # would drop PATH, HOME, CUDA settings, etc. for the pip subprocess.
        # FLASH_ATTENTION_SKIP_CUDA_BUILD asks flash-attn's build to skip
        # compiling CUDA kernels (presumably using a prebuilt wheel instead).
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        # Fail loudly with pip's exit status instead of a confusing ImportError.
        check=True,
    )
    import flash_attn

    print("flash-attn installed.")

import os
import uuid
import requests

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread

import gradio as gr
from dotenv import load_dotenv

import spaces


# Load variables from a local .env file (if present) before reading config.
load_dotenv()

MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"
# Endpoint that receives the conversation logs and like/dislike scores.
PREFERENCE_API_URL = os.getenv("PREFERENCE_API_URL")
if not PREFERENCE_API_URL:
    # Raise explicitly: an `assert` would be stripped when running with -O.
    raise RuntimeError("PREFERENCE_API_URL is not set")

# Load the weights 4-bit quantized (NF4 with double quantization) while doing
# the compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",  # let transformers/accelerate choose device placement
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): torch.compile is lazy and returns a wrapper whose compiled
# entry point is forward(); `model.generate(...)` is looked up on the original
# module and may never go through the compiled graph. Confirm that compilation
# actually takes effect (and works with 4-bit bitsandbytes layers).
print("Compiling model...")
model = torch.compile(model)
print("Model compiled.")


def send_report(
    type: str,
    data: dict,
    timeout: float = 10.0,
):
    """POST a report to ``PREFERENCE_API_URL`` on a best-effort basis.

    The JSON body is ``{"type": type, **data}``. Any failure (connection error,
    timeout, non-2xx status, non-JSON response body) is printed and swallowed so
    that logging can never break the chat UI.

    Args:
        type: Report kind, e.g. ``"conversation"`` or ``"score"``.
        data: Extra fields merged into the JSON body.
        timeout: Seconds to wait for the API before giving up. Without it,
            ``requests`` waits indefinitely and would stall the caller.
    """
    print(f"Sending report: {data}")
    try:
        res = requests.post(
            PREFERENCE_API_URL, json={"type": type, **data}, timeout=timeout
        )
        # Treat HTTP error statuses as failures instead of reporting success.
        res.raise_for_status()
        print(f"Report sent: {res.json()}")
    except Exception as e:  # best effort: reporting must not crash the app
        print(f"Failed to send report: {e}")


def send_reply(
    reply_id: str,
    parent_id: str,
    role: str,
    body: str,
):
    """Log a single chat message as a ``"conversation"`` report.

    Args:
        reply_id: ID of this message.
        parent_id: ID of the message it answers ("" for the root message).
        role: ``"system"``, ``"user"`` or ``"assistant"``.
        body: The message text.
    """
    payload = dict(reply_id=reply_id, parent_id=parent_id, role=role, body=body)
    send_report("conversation", payload)


def send_score(
    reply_id: str,
    score: int,
):
    """Log a rating of message ``reply_id`` as a ``"score"`` report.

    The UI sends ``1`` for a like and ``-1`` for a dislike.
    """
    send_report("score", dict(reply_id=reply_id, score=score))


def generate_unique_id():
    """Return a fresh random (version 4) UUID in canonical string form."""
    fresh = uuid.uuid4()
    return f"{fresh}"


@spaces.GPU(duration=45)
def generate(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    """Stream the model's reply to ``message`` given the prior conversation.

    Args:
        message: The new user message.
        history: Previous ``(user, assistant)`` turns; empty entries are skipped.
        system_message: System prompt placed first in the chat template.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.

    Yields:
        ``("", new_history)`` pairs. The empty string clears the input box.
        ``new_history`` is ``history`` plus ``(message, partial_reply)``, and the
        reply grows as text streams in. For a blank ``message`` a single
        ``("", history)`` pair is yielded and nothing is generated.
    """
    if not message or message.strip() == "":
        # This is a generator: `return value` would discard the value and
        # yield nothing, leaving callers with no output at all.
        yield "", history
        return

    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    # The streamer yields decoded text as generate() produces it in the worker
    # thread; iteration raises if no new text arrives within `timeout` seconds.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        # Empty string clears the input textbox on every update.
        yield "", history + [(message, partial_message)]

    # The streamer is exhausted once generation has ended; reap the thread.
    worker.join()


def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    """Stream a reply to a new user message, then log the exchange.

    ``reply_ids`` mirrors the conversation as
    ``[system_id, user_1, assistant_1, user_2, assistant_2, ...]``. Two fresh
    IDs are appended for this turn, and a new list is returned through the
    outputs rather than the input state being mutated.

    Once streaming finishes, the messages are logged via ``send_reply``:

    * the system message, on the first turn only;
    * the user message;
    * the final assistant reply.

    Yields:
        ``(textbox_value, chat_history, reply_ids)`` tuples for the Gradio
        outputs ``[input_text, chat_history, reply_ids]``. A blank ``message``
        yields the unchanged state once, with an empty textbox value, and
        nothing is logged.
    """
    # Ignore blank submissions before touching the IDs, so reply_ids stays
    # aligned with the chat history and nothing half-finished is logged.
    if not message or not message.strip():
        yield "", history, reply_ids
        return

    # Copy so the session state only changes through the yielded outputs.
    reply_ids = list(reply_ids) if reply_ids else [generate_unique_id()]
    last_reply_id = reply_ids[-1]
    user_reply_id = generate_unique_id()
    assistant_reply_id = generate_unique_id()
    reply_ids += [user_reply_id, assistant_reply_id]

    final_history = None
    for textbox_value, final_history in generate(
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield textbox_value, final_history, reply_ids

    # Log only if this turn actually made it into the history.
    if final_history is None or len(final_history) <= len(history):
        return
    if len(reply_ids) == 3:
        # First turn of the conversation: record the system prompt as the root.
        send_reply(reply_ids[0], "", "system", system_message)
    send_reply(user_reply_id, last_reply_id, "user", message)
    send_reply(assistant_reply_id, user_reply_id, "assistant", final_history[-1][1])


def retry(
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    """Regenerate the assistant's answer to the most recent user message.

    The last ``(user, assistant)`` turn is dropped and the model is re-run on
    the same user message. The user message keeps its reply ID. The old
    assistant ID is replaced by a fresh one, so the new answer is logged with
    the same parent as the answer it replaces.

    Yields:
        ``(textbox_value, chat_history, reply_ids)`` tuples for the Gradio
        outputs ``[input_text, chat_history, reply_ids]``. If there is no turn
        to regenerate, the inputs are yielded once unchanged (with an empty
        textbox value) and nothing is logged.
    """
    # Need a previous turn and its [..., user_id, assistant_id] entries;
    # otherwise history[-1] / reply_ids[-2] below would raise IndexError.
    if not history or len(reply_ids or ()) < 3:
        yield "", history, reply_ids
        return

    # Drop the last turn and regenerate from its user message.
    user_message = history[-1][0]
    history = history[:-1]

    user_reply_id = reply_ids[-2]
    assistant_reply_id = generate_unique_id()
    # New list: replace the old assistant ID without mutating the input state.
    reply_ids = reply_ids[:-1] + [assistant_reply_id]

    final_history = None
    for textbox_value, final_history in generate(
        user_message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield textbox_value, final_history, reply_ids

    # Log only if the turn was actually regenerated.
    if final_history is None or len(final_history) <= len(history):
        return
    send_reply(assistant_reply_id, user_reply_id, "assistant", final_history[-1][1])


def like_reponse(like_data: gr.LikeData, reply_ids: list[str]):
    """Send a like/dislike on a chat message to the preference API.

    ``reply_ids`` is laid out as
    ``[system_id, user_1, assistant_1, user_2, assistant_2, ...]``, matching the
    order in which ``respond`` appends IDs. For the tuple-format chatbot,
    ``like_data.index`` is ``[row, column]`` (column 0 = user, 1 = assistant),
    so the rated message's ID is at ``1 + 2 * row + column``.

    Raises:
        TypeError: if ``like_data.index`` is not a ``[row, column]`` pair.
    """
    index = like_data.index
    if not isinstance(index, (list, tuple)) or len(index) != 2:
        raise TypeError(f"Expected a [row, column] like index, got {index!r}")
    row, column = index
    # Skip the system ID, then two IDs (user, assistant) per chat row.
    position = 1 + 2 * row + column
    if not 0 <= position < len(reply_ids):
        # State and chat are out of sync; drop the rating rather than mislabel it.
        print(f"Ignoring rating for unknown message index: {index}")
        return
    send_score(reply_ids[position], 1 if like_data.liked else -1)


def demo():
    """Build the Tanuki chat demo UI, wire up its events and launch it."""
    intro_markdown = """\
# Tanuki 8B Instruct デモ
モデル: https://huggingface.co/hatakeyama-llm-team/Tanuki-8B-Instruct

アシスタントの回答が不適切だと思った場合は **低評価ボタンを押して低評価を送信**、同様に、回答が素晴らしいと思った場合は**高評価ボタンを押して高評価を送信**することで、モデルの改善に貢献できます。

## 注意点
**本デモに入力されたデータ・会話は匿名で全て記録されます**。これらのデータは Tanuki の学習に利用する可能性があります。そのため、**機密情報・個人情報を入力しないでください**。
"""
    example_questions = (
        "たぬきってなんですか?",
        "情けは人の為ならずとはどういう意味ですか?",
        "まどマギで一番可愛いのは誰?",
        "明晰夢とはなんですか?",
        "シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?",
    )

    with gr.Blocks() as ui:
        gr.Markdown(intro_markdown)

        # Per-session message IDs, seeded with the system message's ID.
        reply_ids = gr.State(value=[generate_unique_id()])

        chat_history = gr.Chatbot(value=[])

        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1, size="sm")
            clear_btn = gr.ClearButton(
                components=[chat_history], value="🗑️ 削除", scale=1, size="sm"
            )

        with gr.Group():
            with gr.Row():
                input_text = gr.Textbox(
                    value="",
                    placeholder="質問を入力してください...",
                    show_label=False,
                    scale=8,
                )
                start_btn = gr.Button(value="送信", variant="primary", scale=1)
            gr.Markdown(
                value="※ 機密情報を入力しないでください。また、Tanuki は誤った情報を生成する可能性があります。"
            )

        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
                label="システムプロンプト",
                value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
            )
            max_new_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=2000, value=250, step=10, label="Top-k"
            )

        gr.Examples(
            examples=[[question] for question in example_questions],
            inputs=[input_text],
            cache_examples=False,
        )

        # Generation settings shared by "send" and "retry", in handler order.
        settings = [
            system_prompt_text,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
        ]

        # Clicking "send" and pressing Enter in the textbox behave identically.
        for trigger in (start_btn.click, input_text.submit):
            trigger(
                respond,
                inputs=[input_text, chat_history, *settings, reply_ids],
                outputs=[input_text, chat_history, reply_ids],
            )
        retry_btn.click(
            retry,
            inputs=[chat_history, *settings, reply_ids],
            outputs=[input_text, chat_history, reply_ids],
        )

        # Forward like/dislike ratings to the preference API.
        chat_history.like(like_reponse, inputs=[reply_ids], outputs=None)

        # Clearing the chat starts over with a fresh system-message ID.
        clear_btn.click(lambda: [generate_unique_id()], outputs=[reply_ids])

    ui.launch()


if __name__ == "__main__":
    # Build and serve the UI only when executed as a script, not on import.
    demo()