try:
    import flash_attn
except ImportError:
    import subprocess

    print("Installing flash-attn...")
    # Install flash-attn without building CUDA kernels from source.
    subprocess.run(
        "uv pip install --system flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
    import flash_attn

    print("flash-attn installed.")

import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread
import gradio as gr
from dotenv import load_dotenv
import spaces

load_dotenv()
HF_API_KEY = os.getenv("HF_API_KEY")

MODEL_NAME = "weblab-GENIAC/Tanuki-8B-dpo-v1.0"

# Load the model with 4-bit NF4 quantization so it fits in limited GPU memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    token=HF_API_KEY,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_API_KEY)

print("Compiling model...")
model = torch.compile(model)
print("Model compiled.")


@spaces.GPU(duration=30)
def generate(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    # Ignore empty input: clear the textbox and keep the history unchanged.
    if not message or message.strip() == "":
        yield "", history
        return

    # Build the conversation in chat-template message format.
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    # Run generation in a background thread and stream tokens as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Initialize the value to return.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        new_history = history + [(message, partial_message)]
        # Clear the input textbox and update the chat with the partial response.
        yield "", new_history


def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    yield from generate(
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    )


def retry(
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    # Nothing to regenerate if there is no conversation yet.
    if not history:
        yield "", history
        return

    # Remove the last exchange and regenerate a response to the same user message.
    last_conversation = history[-1]
    user_message = last_conversation[0]
    history = history[:-1]

    yield from generate(
        user_message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    )


def demo():
    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# weblab-GENIAC/Tanuki-8B-dpo-v1.0 デモ
モデル: https://huggingface.co/weblab-GENIAC/Tanuki-8B-dpo-v1.0
"""
        )
        chat_history = gr.Chatbot(value=[])
        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1, size="sm")
            clear_btn = gr.ClearButton(
                components=[chat_history], value="🗑️ 削除", scale=1, size="sm"
            )
        with gr.Group():
            with gr.Row():
                input_text = gr.Textbox(
                    value="",
                    placeholder="質問を入力してください...",
                    show_label=False,
                    scale=8,
                )
                start_btn = gr.Button(
                    value="送信",
                    variant="primary",
                    scale=1,
                )
        gr.Markdown(
            value="※ 機密情報を入力しないでください。また、Tanuki は誤った情報を生成する可能性があります。"
        )
        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
label="システムプロンプト", value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。", ) max_new_tokens_slider = gr.Slider( minimum=1, maximum=2048, value=512, step=1, label="Max new tokens" ) temperature_slider = gr.Slider( minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature" ) top_p_slider = gr.Slider( minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", ) top_k_slider = gr.Slider( minimum=1, maximum=2000, value=250, step=10, label="Top-k" ) gr.Examples( examples=[ ["たぬきってなんですか?"], ["情けは人の為ならずとはどういう意味ですか?"], ["まどマギで一番可愛いのは誰?"], ], inputs=[input_text], cache_examples=False, ) start_btn.click( respond, inputs=[ input_text, chat_history, system_prompt_text, max_new_tokens_slider, temperature_slider, top_p_slider, top_k_slider, ], outputs=[input_text, chat_history], ) input_text.submit( respond, inputs=[ input_text, chat_history, system_prompt_text, max_new_tokens_slider, temperature_slider, top_p_slider, top_k_slider, ], outputs=[input_text, chat_history], ) retry_btn.click( retry, inputs=[ chat_history, system_prompt_text, max_new_tokens_slider, temperature_slider, top_p_slider, top_k_slider, ], outputs=[input_text, chat_history], ) ui.launch() if __name__ == "__main__": demo()