# app.py
import gradio as gr
import spaces
from threading import Thread

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

# ------------------------------
# 1. Load the model and tokenizer
# ------------------------------
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# If the tokenizer has no pad_token_id, explicitly fall back to eos_token_id
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# ------------------------------
# 2. Chat history -> prompt format
# ------------------------------
def preprocess_messages(history):
    """
    Concatenate the chat history into a minimal prompt.
    You can customize the prompt format or special tokens to better
    suit this model.
    """
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    # Cue the model to continue as "Assistant:"
    prompt += "Assistant: "
    return prompt

# ------------------------------
# 3. Prediction / inference
# ------------------------------
@spaces.GPU()  # Request a GPU on Hugging Face Spaces
def predict(history, max_length, top_p, temperature):
    """
    Generate text from the current history, streaming tokens with
    Hugging Face's TextIteratorStreamer.
    """
    prompt = preprocess_messages(history)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,      # pad automatically
        truncation=True,   # truncate overlong inputs
        max_length=2048,   # adjust for your VRAM or the model's context limit
    )
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Streaming iterator over newly generated tokens
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_length,  # number of newly generated tokens
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
    }

    # Run generate() in a background thread; the main thread reads new tokens
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Append each newly generated token to history[-1][1] as it arrives
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history
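# ------------------------------
# Optional: chat-template-based prompt building
# ------------------------------
# A minimal sketch of an alternative to preprocess_messages above, assuming
# the tokenizer ships a chat template (common for recent instruct-tuned
# models). `build_prompt_with_template` is a hypothetical helper, not part
# of the original app; swap it in for preprocess_messages inside predict()
# if the template is available.
def build_prompt_with_template(history):
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # add_generation_prompt=True appends the assistant-turn marker so the
    # model continues as the assistant
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )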

# ------------------------------
# 4. Gradio UI
# ------------------------------
def main():
    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center;'>DeepScaleR-1.5B Chat Demo</h1>")

        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="Type your question...",
                    label="User Input",
                )
                submitBtn = gr.Button("Submit")
                clearBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0,
                    maximum=1024,  # raise or lower as needed
                    value=512,
                    step=1,
                    label="Max New Tokens",
                    interactive=True,
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True,
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.01,
                    label="Temperature",
                    interactive=True,
                )

        # On Submit: first append the input to history, then call predict
        def user(query, history):
            return "", history + [[query, ""]]

        submitBtn.click(
            fn=user,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False,  # skip the queue for this fast step
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot,
        )

        # Clear the chat history and the input box
        def clear_history():
            return [], ""

        clearBtn.click(
            fn=clear_history,
            inputs=[],
            outputs=[chatbot, user_input],
            queue=False,
        )

    # Optional: queue requests to serialize concurrent generations.
    # (Gradio 3.x used queue(concurrency_count=1); Gradio 4+ dropped that
    # kwarg in favor of default_concurrency_limit or a plain queue() call.)
    demo.queue()
    demo.launch()

# ------------------------------
# Entry point
# ------------------------------
if __name__ == "__main__":
    main()