import gradio as gr
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

model_name = "fzmnm/TinyStoriesAdv_v2_92M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.generation_config.pad_token_id = tokenizer.eos_token_id

# Maximum number of prompt tokens kept as context (older turns are truncated).
max_tokens = 512


def build_input_str(message: str, history: list[dict]) -> str:
    """Flatten the messages-format chat history into the 问/答
    (question/answer) prompt format the model was trained on."""
    history_str = ""
    for entity in history:
        if entity["role"] == "user":
            history_str += f"问:{entity['content']}\n\n"  # 问 = question
        elif entity["role"] == "assistant":
            history_str += f"答:{entity['content']}\n\n"  # 答 = answer
    return history_str + f"问:{message}\n\n"


def stop_criteria(input_str: str) -> bool:
    # Stop once the model has finished an answer: the decoded text ends
    # with a newline and is not empty.
    return input_str.endswith("\n") and len(input_str.strip()) > 0


class StopOnTokens(StoppingCriteria):
    # Called after every generated token; decodes the full sequence so far
    # and checks the plain-text stop condition above.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        input_str = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return stop_criteria(input_str)


def chat(message, history):
    input_str = build_input_str(message, history)
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    input_ids = input_ids[:, -max_tokens:]  # keep only the most recent context
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10, skip_prompt=True, skip_special_tokens=True
    )
    stopping_criteria = StoppingCriteriaList([StopOnTokens()])
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
        max_new_tokens=512,
        top_p=0.9,
        do_sample=True,
        temperature=0.7,
    )
    # Run generation in a background thread so this generator can consume
    # the streamer incrementally and yield partial output to the UI.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    output_str = ""
    for new_str in streamer:
        output_str += new_str
        yield output_str


app = gr.ChatInterface(
    fn=chat,
    type="messages",
    # "What is a parrot?", "What is an elephant?", "Who is Li Bai?", "What is a black hole?"
    examples=["什么是鹦鹉?", "什么是大象?", "谁是李白?", "什么是黑洞?"],
    title="聊天机器人",  # "Chatbot"
)
app.launch()
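
# A quick sanity check of the prompt builder (not part of the app; the sample
# assistant reply below is made up for illustration). It assumes the
# messages-format history that gr.ChatInterface(type="messages") passes in:
#
#   history = [
#       {"role": "user", "content": "什么是鹦鹉?"},
#       {"role": "assistant", "content": "鹦鹉是一种鸟。"},  # "A parrot is a bird."
#   ]
#   build_input_str("什么是大象?", history)
#   # -> '问:什么是鹦鹉?\n\n答:鹦鹉是一种鸟。\n\n问:什么是大象?\n\n'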