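# NOTE: the following lines appear to be residue from the web file viewer this
# file was copied from (commit message, commit hash, page controls, file size).
# Plat
# chore: report conversation data
# fd490fb
# raw
# history blame
# 10.3 kB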
# Ensure flash-attn is importable, installing it with pip on first start if it
# is missing.
try:
    import flash_attn
except ImportError:
    # Only a missing/broken module should trigger an install; a bare `except:`
    # would also swallow KeyboardInterrupt and unrelated errors.
    import os
    import subprocess
    import sys

    print("Installing flash-attn...")
    subprocess.run(
        # Use this interpreter's pip (not whatever `pip` is first on PATH) and
        # avoid going through a shell.
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Merge into the current environment: passing only this one variable
        # would drop PATH, HOME, virtualenv settings, etc. for the pip process.
        # The variable presumably tells flash-attn's build to skip compiling
        # CUDA kernels locally -- TODO confirm against its setup script.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    )
    # If the install failed, this import raises and stops the app.
    import flash_attn

    print("flash-attn installed.")
import os
import uuid
import requests
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
BitsAndBytesConfig,
)
from threading import Thread
import gradio as gr
from dotenv import load_dotenv
import spaces
# Populate os.environ from a local .env file (if present) before the settings
# below are read.
load_dotenv()

# Hugging Face model repository served by this demo.
MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"
# Endpoint that receives conversation and score reports (used by send_report).
PREFERENCE_API_URL = os.getenv("PREFERENCE_API_URL")
if not PREFERENCE_API_URL:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("PREFERENCE_API_URL is not set")
# Quantize the weights to 4-bit NF4 with double (nested) quantization; the
# compute dtype is bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",  # let transformers choose device placement
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("Compiling model...")
# NOTE(review): torch.compile returns a wrapper module, and generate() is
# resolved on the wrapped model; confirm the compiled forward is actually
# exercised during generation, otherwise this step may add nothing.
model = torch.compile(model)
print("Model compiled.")
def send_report(
    type: str,
    data: dict,
    timeout: float = 10.0,
):
    """POST one report record to PREFERENCE_API_URL, best-effort.

    The JSON body is ``data`` plus a ``"type"`` key. Any failure (network
    error, timeout, HTTP error status, non-JSON reply) is printed and
    swallowed so that reporting can never break the chat UI.

    Args:
        type: Report kind, e.g. ``"conversation"`` or ``"score"``. (Shadows
            the builtin; name kept so keyword callers keep working.)
        data: JSON-serializable fields merged into the request body.
        timeout: Seconds to wait for the endpoint before giving up. Without
            a timeout a stalled endpoint would block the calling chat event
            indefinitely.
    """
    print(f"Sending report: {data}")
    try:
        res = requests.post(
            PREFERENCE_API_URL, json={"type": type, **data}, timeout=timeout
        )
        # Treat 4xx/5xx as failures instead of logging them as "sent".
        res.raise_for_status()
        print(f"Report sent: {res.json()}")
    except Exception as e:
        # Deliberately best-effort: a reporting failure must not interrupt chat.
        print(f"Failed to send report: {e}")
def send_reply(
    reply_id: str,
    parent_id: str,
    role: str,
    body: str,
):
    """Report a single chat message as a ``"conversation"`` record.

    Args:
        reply_id: ID of this message.
        parent_id: ID of the message it follows (``""`` for the root).
        role: ``"system"``, ``"user"`` or ``"assistant"``.
        body: Message text.
    """
    payload = dict(reply_id=reply_id, parent_id=parent_id, role=role, body=body)
    send_report("conversation", payload)
def send_score(
    reply_id: str,
    score: int,
):
    """Report a rating (``"score"`` record) for the message ``reply_id``."""
    send_report("score", dict(reply_id=reply_id, score=score))
def generate_unique_id():
    """Return a fresh random (version 4) UUID as its canonical string."""
    fresh = uuid.uuid4()
    return f"{fresh}"
# Hugging Face `spaces` GPU decorator; presumably grants a GPU for up to 45 s
# per call -- confirm against the `spaces` package docs.
@spaces.GPU(duration=45)
def generate(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    """Stream the model's reply to ``message``.

    Prompt construction:
        The prompt is the system message, then every non-empty user/assistant
        entry of ``history`` in order, then ``message`` as the newest user
        turn, all rendered with the tokenizer's chat template.

    Generation:
        ``model.generate`` runs on a background thread, and text is consumed
        here through a ``TextIteratorStreamer``.

    Yields:
        ``("", new_history)`` after each streamed chunk, where ``new_history``
        is ``history + [(message, reply_so_far)]``. The leading ``""`` is the
        new value for the input textbox (i.e. it clears it). For a blank
        ``message``, exactly one ``("", history)`` is yielded and nothing is
        generated.
    """
    if not message or not message.strip():
        # In a generator, `return value` is invisible to a `for` loop, so
        # yield the unchanged state explicitly; otherwise callers get no
        # output at all.
        yield "", history
        return

    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    # With a timeout, iterating the streamer gives up when no new text arrives
    # for 10 s (e.g. the generation thread failed) instead of blocking forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    # model.generate blocks until finished, so run it on a thread and consume
    # the streamer here to forward text as it is produced.
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        # "" clears the input textbox in the UI.
        yield "", history + [(message, partial_message)]
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    """Answer a new user message by streaming the reply, then report the turn.

    ``reply_ids`` holds one ID per recorded message, in the order
    ``[system_id, user_1, assistant_1, user_2, assistant_2, ...]``. An empty
    list is given a fresh system ID, and two new IDs are appended for this
    turn.

    Yields:
        ``(input_text, chat_history, reply_ids)`` for the three UI outputs.

    Reporting:
        After streaming finishes, the following are sent via ``send_reply``,
        each linked to its parent ID:
        - the system message (first turn only);
        - the user message;
        - the final assistant reply.
        A blank ``message`` is neither answered nor recorded.
    """
    if not message or not message.strip():
        # Nothing to answer or record: just clear the input box.
        yield "", history, reply_ids
        return

    # Work on a copy so the session State list is never mutated in place; the
    # updated list reaches the session only through the yields below.
    reply_ids = list(reply_ids) if reply_ids else [generate_unique_id()]
    is_first_turn = len(reply_ids) == 1
    last_reply_id = reply_ids[-1]
    user_reply_id = generate_unique_id()
    assistant_reply_id = generate_unique_id()
    reply_ids += [user_reply_id, assistant_reply_id]

    final_history = None
    for stream in generate(
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        final_history = stream[1]
        yield (*stream, reply_ids)

    if final_history is None:
        # generate() produced no output, so there is no reply to record.
        return

    # Record the turn.
    if is_first_turn:
        send_reply(reply_ids[0], "", "system", system_message)
    send_reply(user_reply_id, last_reply_id, "user", message)
    send_reply(assistant_reply_id, user_reply_id, "assistant", final_history[-1][1])
def retry(
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    """Regenerate the assistant reply to the most recent user message.

    The last (user, assistant) row is dropped from ``history`` and its user
    message is sent to ``generate`` again. The trailing assistant ID in
    ``reply_ids`` (laid out ``[system_id, user_1, assistant_1, ...]``) is
    replaced by a fresh one, and the new reply is reported with the same user
    message as its parent.

    Yields:
        ``(input_text, chat_history, reply_ids)`` for the three UI outputs.
        If there is nothing to regenerate (empty chat, blank last user message,
        or too few IDs), a single no-change update is yielded.
    """
    last_user_message = history[-1][0] if history else ""
    if len(reply_ids) < 3 or not last_user_message or not last_user_message.strip():
        # E.g. Retry clicked on an empty chat: indexing below would raise
        # IndexError. Leave every output as it is.
        yield gr.update(), history, reply_ids
        return

    previous_history = history[:-1]
    user_reply_id = reply_ids[-2]
    assistant_reply_id = generate_unique_id()
    # Build a new list (slicing copies) instead of mutating the session State.
    reply_ids = reply_ids[:-1] + [assistant_reply_id]

    final_history = None
    for stream in generate(
        last_user_message,
        previous_history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        final_history = stream[1]
        yield (*stream, reply_ids)

    if final_history is not None:
        # Record the regenerated reply under the same user message.
        send_reply(assistant_reply_id, user_reply_id, "assistant", final_history[-1][1])
def like_reponse(like_data: gr.LikeData, reply_ids: list[str]):
    """Send a +1/-1 score for the chat message the user liked or disliked.

    ``reply_ids`` is laid out as ``[system_id, user_1, assistant_1, user_2,
    assistant_2, ...]`` (built by ``respond``). So the message at chatbot row
    ``row``, column ``col`` (0 = user, 1 = assistant) has ID
    ``reply_ids[2 * row + col + 1]``.
    """
    index = like_data.index
    if isinstance(index, (list, tuple)):
        # NOTE: assumes Gradio's tuple-format Chatbot reports [row, col] --
        # confirm for the installed Gradio version.
        row, col = index[0], index[1]
        position = 2 * row + col + 1
    elif isinstance(index, int):
        # Flat (messages-format) index: message i follows the system ID.
        position = index + 1
    else:
        print(f"Unexpected like index: {index!r}")
        return

    if not 0 <= position < len(reply_ids):
        # reply_ids is out of sync with the chat; don't score the wrong message.
        print(f"No reply_id for like index {index!r}")
        return

    send_score(reply_ids[position], 1 if like_data.liked else -1)
def demo():
    """Build the Gradio chat UI, wire up its events, and launch it."""
    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# Tanuki 8B Instruct デモ
モデル: https://huggingface.co/hatakeyama-llm-team/Tanuki-8B-Instruct
アシスタントの回答が不適切だと思った場合は **低評価ボタンを押して低評価を送信**、同様に、回答が素晴らしいと思った場合は**高評価ボタンを押して高評価を送信**することで、モデルの改善に貢献できます。
## 注意点
**本デモに入力されたデータ・会話は匿名で全て記録されます**。これらのデータは Tanuki の学習に利用する可能性があります。そのため、**機密情報・個人情報を入力しないでください**。
"""
        )
        # Per-session message IDs: [system_id, user_1, assistant_1, ...].
        # Start empty instead of with a UUID. gr.State deep-copies its initial
        # value into every session, so an ID generated once here would be
        # shared by all users. respond() creates a fresh system ID when the
        # list is empty.
        reply_ids = gr.State(value=[])
        chat_history = gr.Chatbot(value=[])
        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1, size="sm")
            clear_btn = gr.ClearButton(
                components=[chat_history], value="🗑️ 削除", scale=1, size="sm"
            )
        with gr.Group():
            with gr.Row():
                input_text = gr.Textbox(
                    value="",
                    placeholder="質問を入力してください...",
                    show_label=False,
                    scale=8,
                )
                start_btn = gr.Button(
                    value="送信",
                    variant="primary",
                    scale=1,
                )
        gr.Markdown(
            value="※ 機密情報を入力しないでください。また、Tanuki は誤った情報を生成する可能性があります。"
        )
        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
                label="システムプロンプト",
                value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
            )
            max_new_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=2000, value=250, step=10, label="Top-k"
            )
        gr.Examples(
            examples=[
                ["たぬきってなんですか?"],
                ["情けは人の為ならずとはどういう意味ですか?"],
                ["まどマギで一番可愛いのは誰?"],
                ["明晰夢とはなんですか?"],
                [
                    "シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"
                ],
            ],
            inputs=[input_text],
            cache_examples=False,
        )

        # Generation settings shared by send, submit and retry.
        settings = [
            system_prompt_text,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
        ]
        chat_outputs = [input_text, chat_history, reply_ids]
        respond_inputs = [input_text, chat_history, *settings, reply_ids]

        start_btn.click(respond, inputs=respond_inputs, outputs=chat_outputs)
        input_text.submit(respond, inputs=respond_inputs, outputs=chat_outputs)
        retry_btn.click(
            retry,
            inputs=[chat_history, *settings, reply_ids],
            outputs=chat_outputs,
        )
        # Send a score when a message is liked or disliked.
        chat_history.like(like_reponse, inputs=[reply_ids], outputs=None)
        # Clearing the chat starts a new conversation with a fresh system ID.
        clear_btn.click(
            lambda: [generate_unique_id()],
            outputs=[reply_ids],
        )
    ui.launch()
if __name__ == "__main__":
    # Script entry point: build and launch the demo UI.
    demo()