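# Gradio chat demo for the Tanuki-8B-Instruct model.
# The model is loaded with 4-bit (NF4) quantization, replies are streamed into
# a chat UI, and every conversation and like/dislike rating is reported to an
# external preference-collection endpoint (PREFERENCE_API_URL).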
# Install flash-attn at runtime if it is missing (the CUDA build step is skipped).
try:
    import flash_attn
except ImportError:
    import subprocess

    print("Installing flash-attn...")
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
    import flash_attn

    print("flash-attn installed.")
import os
import uuid
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread
import gradio as gr
from dotenv import load_dotenv
import spaces

load_dotenv()
MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"
PREFERENCE_API_URL = os.getenv("PREFERENCE_API_URL")
assert PREFERENCE_API_URL, "PREFERENCE_API_URL is not set"
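# Load the model with 4-bit NF4 quantization (bfloat16 compute, double quantization)
# to keep GPU memory usage low, then compile it for faster inference.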
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, quantization_config=quantization_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("Compiling model...")
model = torch.compile(model)
print("Model compiled.")
def send_report(
    report_type: str,
    data: dict,
):
    print(f"Sending report: {data}")
    try:
        res = requests.post(PREFERENCE_API_URL, json={"type": report_type, **data})
        print(f"Report sent: {res.json()}")
    except Exception as e:
        print(f"Failed to send report: {e}")
def send_reply(
    reply_id: str,
    parent_id: str,
    role: str,
    body: str,
):
    send_report(
        "conversation",
        {
            "reply_id": reply_id,
            "parent_id": parent_id,
            "role": role,
            "body": body,
        },
    )
def send_score(
    reply_id: str,
    score: int,
):
    # print(f"Score: {score}, reply_id: {reply_id}")
    send_report(
        "score",
        {
            "reply_id": reply_id,
            "score": score,
        },
    )
def generate_unique_id():
    return str(uuid.uuid4())
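# Stream a chat completion: the prompt is built with the tokenizer's chat
# template, model.generate runs in a background thread, and tokens arrive
# through a TextIteratorStreamer that this generator yields from one by one.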
def generate(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    if not message or message.strip() == "":
        # Nothing to generate. This is a generator, so yield the unchanged
        # history (a bare `return value` would never reach the caller's loop).
        yield "", history
        return

    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Initialize the value to return.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        new_history = history + [(message, partial_message)]
        # Clear the input text (first output) while streaming the updated history.
        yield "", new_history
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    if not message or message.strip() == "":
        # Ignore empty input: keep the history and reply IDs unchanged.
        yield "", history, reply_ids
        return

    if len(reply_ids) == 0:
        reply_ids = [generate_unique_id()]
    last_reply_id = reply_ids[-1]

    user_reply_id = generate_unique_id()
    assistant_reply_id = generate_unique_id()
    reply_ids.append(user_reply_id)
    reply_ids.append(assistant_reply_id)

    for stream in generate(
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield *stream, reply_ids

    # Log the exchange.
    if len(reply_ids) == 3:
        send_reply(reply_ids[0], "", "system", system_message)
    send_reply(user_reply_id, last_reply_id, "user", message)
    send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])
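# Handler for the retry button: drop the last assistant answer, regenerate it
# for the same user message, and replace the old assistant reply ID with a new one.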
def retry(
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    if not history:
        # Nothing to regenerate yet.
        yield "", history, reply_ids
        return

    # Delete the last message pair and regenerate a reply for its user message.
    last_conversation = history[-1]
    user_message = last_conversation[0]
    history = history[:-1]

    user_reply_id = reply_ids[-2]
    reply_ids = reply_ids[:-1]
    assistant_reply_id = generate_unique_id()
    reply_ids.append(assistant_reply_id)

    for stream in generate(
        user_message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield *stream, reply_ids

    # Log the regenerated reply.
    send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])
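# Handler for like/dislike events on the chatbot: forwards the rating to the
# preference API as a +1 / -1 score for the selected reply.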
def like_response(like_data: gr.LikeData, reply_ids: list[str]):
    # print(like_data.index, like_data.value, like_data.liked)
    assert isinstance(like_data.index, list)
    # Send the rating: +1 for a like, -1 for a dislike.
    send_score(reply_ids[like_data.index[0] + 1], 1 if like_data.liked else -1)
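# Build the Gradio Blocks UI: notice text, chat area, retry/clear buttons,
# input box, generation-parameter accordion, and example prompts.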
def demo():
    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# Tanuki 8B Instruct デモ

モデル: https://huggingface.co/hatakeyama-llm-team/Tanuki-8B-Instruct

アシスタントの回答が不適切だと思った場合は **低評価ボタンを押して低評価を送信**、同様に、回答が素晴らしいと思った場合は**高評価ボタンを押して高評価を送信**することで、モデルの改善に貢献できます。

## 注意点

**本デモに入力されたデータ・会話は匿名で全て記録されます**。これらのデータは Tanuki の学習に利用する可能性があります。そのため、**機密情報・個人情報を入力しないでください**。
"""
        )
        reply_ids = gr.State(value=[generate_unique_id()])
        chat_history = gr.Chatbot(value=[])
        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1, size="sm")
            clear_btn = gr.ClearButton(
                components=[chat_history], value="🗑️ 削除", scale=1, size="sm"
            )
        with gr.Group():
            with gr.Row():
                input_text = gr.Textbox(
                    value="",
                    placeholder="質問を入力してください...",
                    show_label=False,
                    scale=8,
                )
                start_btn = gr.Button(
                    value="送信",
                    variant="primary",
                    scale=1,
                )
            gr.Markdown(
                value="※ 機密情報を入力しないでください。また、Tanuki は誤った情報を生成する可能性があります。"
            )
        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
                label="システムプロンプト",
                value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
            )
            max_new_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=2000, value=250, step=10, label="Top-k"
            )
        gr.Examples(
            examples=[
                ["たぬきってなんですか?"],
                ["情けは人の為ならずとはどういう意味ですか?"],
                ["まどマギで一番可愛いのは誰?"],
                ["明晰夢とはなんですか?"],
                [
                    "シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"
                ],
            ],
            inputs=[input_text],
            cache_examples=False,
        )
        start_btn.click(
            respond,
            inputs=[
                input_text,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                reply_ids,
            ],
            outputs=[input_text, chat_history, reply_ids],
        )
        input_text.submit(
            respond,
            inputs=[
                input_text,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                reply_ids,
            ],
            outputs=[input_text, chat_history, reply_ids],
        )
        retry_btn.click(
            retry,
            inputs=[
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                reply_ids,
            ],
            outputs=[input_text, chat_history, reply_ids],
        )
        # When a reply is rated (liked or disliked)
        chat_history.like(like_response, inputs=[reply_ids], outputs=None)
        clear_btn.click(
            lambda: [generate_unique_id()],  # generate a fresh ID for the system message
            outputs=[reply_ids],
        )

    ui.launch()
if __name__ == "__main__":
    demo()