Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
import torch | |
from threading import Thread | |
import os

# Run from this file's own directory so relative paths (such as a local
# checkpoint directory) resolve no matter where the script was launched from.
# abspath() guards against __file__ being a bare relative name, in which case
# dirname() would return "" and os.chdir("") would raise FileNotFoundError.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Hub id of the checkpoint to serve. A local directory can be used instead,
# e.g. model_name = "./92M_low_kv_dropout_v3_hf".
model_name = "fzmnm/TinyStoriesAdv_v2_92M"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()  # inference only: put the module in evaluation mode

# Reuse the EOS token id as the padding id for generation.
# NOTE(review): presumably the tokenizer defines no dedicated pad token -- confirm.
model.generation_config.pad_token_id = tokenizer.eos_token_id

# Maximum number of prompt tokens kept by chat(); older context is cut off.
max_tokens = 512
def build_input_str(message: str, history: 'list[dict[str, str]]') -> str:
    """Render the chat history plus the new user message as a model prompt.

    Args:
        message: The new user message to answer.
        history: Earlier turns in "messages" format; each entry is a mapping
            with a 'role' key ('user' or 'assistant') and a 'content' key.
            Entries with any other role are skipped.

    Returns:
        The prompt text: one "问:<content>" block per user turn and one
        "答:<content>" block per assistant turn, each followed by a blank
        line, ending with the new message as a final "问:" block.
    """
    role_prefixes = {'user': "问:", 'assistant': "答:"}
    turns = [
        f"{role_prefixes[entry['role']]}{entry['content']}\n\n"
        for entry in history
        if entry['role'] in role_prefixes
    ]
    turns.append(f"问:{message}\n\n")
    return "".join(turns)
# The model may emit either the ASCII colon or the fullwidth colon (U+FF1A);
# both are treated as the same character when looking for stop markers.
_FULLWIDTH_COLON = "\uff1a"

# Trailing markers that mean the answer is over: the model has started a new
# question turn ("问:") or a "meta_tag:" section.
_STOP_MARKERS = ("问:", "meta_tag:")


def _normalize_colons(text):
    """Return *text* with fullwidth colons replaced by ASCII colons."""
    # One character is swapped for one character, so the result has the same
    # length as the input and suffix lengths carry over to the original text.
    return text.replace(_FULLWIDTH_COLON, ":")


def stop_criteria(input_str):
    """Return True when *input_str* ends with one of the stop markers.

    The check is colon-width insensitive: a marker written with a fullwidth
    colon counts as well.
    """
    return _normalize_colons(input_str).endswith(_STOP_MARKERS)


def remove_ending(input_str):
    """Return *input_str* with a trailing stop marker removed, if present.

    Uses the same markers and the same colon normalization as
    stop_criteria(), so whatever stops generation is also hidden from the
    displayed text. Text without a trailing marker is returned unchanged.
    """
    normalized = _normalize_colons(input_str)
    for marker in _STOP_MARKERS:
        if normalized.endswith(marker):
            return input_str[:-len(marker)]
    return input_str
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop marker appears.

    On every call the first sequence of the batch is decoded back to text
    (special tokens skipped) and handed to stop_criteria(); generation stops
    when that function returns True.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first row is inspected.
        first_sequence = input_ids[0]
        decoded_text = tokenizer.decode(first_sequence, skip_special_tokens=True)
        return stop_criteria(decoded_text)
def chat(message, history):
    """Stream the model's reply to *message* given the earlier *history*.

    Args:
        message: The new user message.
        history: Earlier turns, passed straight to build_input_str().

    Yields:
        The reply accumulated so far, with any trailing stop marker removed
        by remove_ending(). Each yielded value is the whole reply up to that
        point, not just the newest fragment.
    """
    prompt = build_input_str(message, history)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Keep only the most recent max_tokens prompt tokens; older context is
    # dropped from the left.
    input_ids = input_ids[:, -max_tokens:]

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10,
        skip_prompt=True,
        skip_special_tokens=True)
    stopping_criteria = StoppingCriteriaList([StopOnTokens()])
    generate_kwargs = dict(
        input_ids=input_ids,
        # A single unpadded prompt: every position is a real token. Passing
        # the mask explicitly is required because the pad token id equals the
        # EOS token id, so generate() cannot infer the mask on its own.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        stopping_criteria=stopping_criteria,
        max_new_tokens=512,
        top_p=0.9,
        do_sample=True,
        temperature=0.7
    )

    # generate() blocks until it finishes, so it runs in a worker thread
    # while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    output_str = ""
    for new_str in streamer:
        output_str += new_str
        yield remove_ending(output_str)

    # The streamer is exhausted only once generation has ended; reap the
    # worker so finished threads do not linger.
    worker.join()
# Sample questions offered in the chat UI.
example_questions = ['什么是鹦鹉?', '什么是大象?', '谁是李白?', '什么是黑洞?']

# Chat UI backed by chat(); history is passed in "messages" format, i.e. as
# entries carrying 'role' and 'content' keys.
app = gr.ChatInterface(
    fn=chat,
    type='messages',
    examples=example_questions,
    title='聊天机器人',
)

# Start the web server.
app.launch()