# Install flash-attn on first run if it is missing (common HF Spaces pattern);
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips the slow from-source CUDA build.
try:
    import flash_attn
except ImportError:
    import subprocess

    print("Installing flash-attn...")
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
    import flash_attn

    print("flash-attn installed.")

import os
import uuid
from threading import Thread

import gradio as gr
import requests
import spaces
import torch
from dotenv import load_dotenv
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

load_dotenv()

MODEL_NAME = "hatakeyama-llm-team/Tanuki-8B-Instruct"
PREFERENCE_API_URL = os.getenv("PREFERENCE_API_URL")
assert PREFERENCE_API_URL, "PREFERENCE_API_URL is not set"

# Load the model with 4-bit NF4 quantization to fit on a single GPU
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, quantization_config=quantization_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("Compiling model...")
model = torch.compile(model)
print("Model compiled.")


def send_report(type: str, data: dict):
    """POST a report to the preference-collection API; failures are non-fatal."""
    print(f"Sending report: {data}")
    try:
        res = requests.post(PREFERENCE_API_URL, json={"type": type, **data})
        print(f"Report sent: {res.json()}")
    except Exception as e:
        print(f"Failed to send report: {e}")


def send_reply(reply_id: str, parent_id: str, role: str, body: str):
    send_report(
        "conversation",
        {
            "reply_id": reply_id,
            "parent_id": parent_id,
            "role": role,
            "body": body,
        },
    )


def send_score(reply_id: str, score: int):
    send_report(
        "score",
        {
            "reply_id": reply_id,
            "score": score,
        },
    )


def generate_unique_id():
    return str(uuid.uuid4())


@spaces.GPU(duration=45)
def generate(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    if not message or message.strip() == "":
        # Empty input: end the stream without yielding anything
        return

    # Rebuild the full conversation for the chat template
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    # Run generation in a background thread and stream tokens as they arrive
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate the reply token by token
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        new_history = history + [(message, partial_message)]
        # First element clears the input textbox
        yield "", new_history


def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    # Ignore empty submissions; generate() yields nothing for them and the
    # logging below would otherwise reference an unbound `stream`.
    if not message or message.strip() == "":
        yield message, history, reply_ids
        return

    if len(reply_ids) == 0:
        reply_ids = [generate_unique_id()]

    last_reply_id = reply_ids[-1]
    user_reply_id = generate_unique_id()
    assistant_reply_id = generate_unique_id()
    reply_ids.append(user_reply_id)
    reply_ids.append(assistant_reply_id)

    for stream in generate(
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield *stream, reply_ids

    # Log the exchange
    if len(reply_ids) == 3:
        # First turn: also log the system prompt under the initial ID
        send_reply(reply_ids[0], "", "system", system_message)
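    # `stream` still holds the last tuple yielded by generate(), i.e.
    # ("", final_history), so stream[1][-1][1] is the assistant's full reply.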
    send_reply(user_reply_id, last_reply_id, "user", message)
    send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])


def retry(
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    reply_ids: list[str],
):
    # Nothing to retry on an empty chat
    if not history:
        yield "", history, reply_ids
        return

    # Drop the last exchange and regenerate from its user message
    last_conversation = history[-1]
    user_message = last_conversation[0]
    history = history[:-1]

    user_reply_id = reply_ids[-2]
    reply_ids = reply_ids[:-1]
    assistant_reply_id = generate_unique_id()
    reply_ids.append(assistant_reply_id)

    for stream in generate(
        user_message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield *stream, reply_ids

    # Log the regenerated assistant reply
    send_reply(assistant_reply_id, user_reply_id, "assistant", stream[1][-1][1])


def like_response(like_data: gr.LikeData, reply_ids: list[str]):
    assert isinstance(like_data.index, list)
    # Submit the rating; the +1 offset skips the system-prompt ID
    send_score(reply_ids[like_data.index[0] + 1], 1 if like_data.liked else -1)


def demo():
    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# Tanuki 8B Instruct デモ

モデル: https://huggingface.co/hatakeyama-llm-team/Tanuki-8B-Instruct

アシスタントの回答が不適切だと思った場合は **低評価ボタンを押して低評価を送信**、同様に、回答が素晴らしいと思った場合は**高評価ボタンを押して高評価を送信**することで、モデルの改善に貢献できます。

## 注意点

**本デモに入力されたデータ・会話は匿名で全て記録されます**。これらのデータは Tanuki の学習に利用する可能性があります。そのため、**機密情報・個人情報を入力しないでください**。
"""
        )

        reply_ids = gr.State(value=[generate_unique_id()])

        chat_history = gr.Chatbot(value=[])
        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1, size="sm")
            clear_btn = gr.ClearButton(
                components=[chat_history], value="🗑️ 削除", scale=1, size="sm"
            )

        with gr.Group():
            with gr.Row():
                input_text = gr.Textbox(
                    value="",
                    placeholder="質問を入力してください...",
                    show_label=False,
                    scale=8,
                )
                start_btn = gr.Button(
                    value="送信",
                    variant="primary",
                    scale=1,
                )
        gr.Markdown(
            value="※ 機密情報を入力しないでください。また、Tanuki は誤った情報を生成する可能性があります。"
        )

        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
                label="システムプロンプト",
                value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
            )
            max_new_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=2000, value=250, step=10, label="Top-k"
            )

        gr.Examples(
            examples=[
                ["たぬきってなんですか?"],
                ["情けは人の為ならずとはどういう意味ですか?"],
                ["まどマギで一番可愛いのは誰?"],
                ["明晰夢とはなんですか?"],
                [
                    "シュレディンガー方程式とシュレディンガーの猫はどのような関係がありますか?"
                ],
            ],
            inputs=[input_text],
            cache_examples=False,
        )

        start_btn.click(
            respond,
            inputs=[
                input_text,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                reply_ids,
            ],
            outputs=[input_text, chat_history, reply_ids],
        )
        input_text.submit(
            respond,
            inputs=[
                input_text,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                reply_ids,
            ],
            outputs=[input_text, chat_history, reply_ids],
        )
        retry_btn.click(
            retry,
            inputs=[
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                reply_ids,
            ],
            outputs=[input_text, chat_history, reply_ids],
        )
        # Fired when a chat message is rated with the like/dislike buttons
        chat_history.like(like_response, inputs=[reply_ids], outputs=None)
        clear_btn.click(
            lambda: [generate_unique_id()],  # fresh ID for the system message
            outputs=[reply_ids],
        )

    ui.launch()


if __name__ == "__main__":
    demo()
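
# Note (assumption, not specified by the source): PREFERENCE_API_URL is expected
# to accept POST requests whose JSON body matches what send_report() emits, e.g.
#   {"type": "conversation", "reply_id": "...", "parent_id": "...",
#    "role": "user", "body": "..."}
# or
#   {"type": "score", "reply_id": "...", "score": 1}
# It can be supplied via a local .env file, which load_dotenv() reads at startup:
#   PREFERENCE_API_URL=https://example.com/api/preference   (hypothetical URL)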