import itertools
import re

import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer
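# Assumed runtime dependencies (nothing is pinned in this file). Note that torch
# is still required even without an explicit import: tokenizer.encode() below
# returns PyTorch tensors, and the on-the-fly ONNX export depends on it.
#   pip install gradio transformers "optimum[onnxruntime]" torch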
user_token = "<User>" | |
eos_token = "<EOS>" | |
bos_token = "<BOS>" | |
bot_token = "<Assistant>" | |
max_context_length = 750 | |
def is_english_word(tested_string):
    """Return True if the string consists solely of ASCII letters."""
    pattern = re.compile(r"^[a-zA-Z]+$")
    return pattern.match(tested_string) is not None
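# Illustrative behavior: is_english_word("hello") -> True,
# is_english_word("你好") -> False, and is_english_word("hello!") -> False,
# since anything other than ASCII letters (punctuation, digits) is excluded.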
def format(history):
    """Serialize alternating user/bot turns into the model's prompt format."""
    prompt = bos_token
    for idx, txt in enumerate(history):
        if idx % 2 == 0:  # even indices are user turns
            prompt += f"{user_token}{txt}{eos_token}"
        else:  # odd indices are assistant turns
            prompt += f"{bot_token}{txt}"
    prompt += bot_token  # trailing tag cues the model to produce the next reply
    print(prompt)  # debug: log the exact prompt fed to the model
    return prompt
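# Illustrative example of the serialized prompt:
#   format(["hi", "hello!", "tell me a joke"])
# returns
#   "<BOS><User>hi<EOS><Assistant>hello!<User>tell me a joke<EOS><Assistant>"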
def gradio(model, tokenizer):
    def response(
        user_input,
        chat_history,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        no_repeat_ngram_size,
    ):
        # chat_history is a list of (user, bot) pairs; flatten it and append
        # the new user message so format() sees one alternating sequence.
        history = list(itertools.chain(*chat_history))
        history.append(user_input)
        prompt = format(history)
        # Truncate on the left so only the most recent tokens fit the context.
        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )[:, -max_context_length:]
        prompt_length = input_ids.shape[1]
        beam_output = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=250,
            num_beams=1,  # no beam search: keeps CPU inference fast
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            early_stopping=True,
            do_sample=True,
        )
        # Drop the prompt tokens; keep only the newly generated reply.
        output = beam_output[0][prompt_length:]
        tokens = tokenizer.convert_ids_to_tokens(output)
        # The BERT-style tokenizer drops spaces, so re-insert one between
        # adjacent English tokens, e.g. ["hello", "world", "你", "好"]
        # joins to "hello world你好" rather than "helloworld你好".
        for i, token in enumerate(tokens[:-1]):
            if is_english_word(token) and is_english_word(tokens[i + 1]):
                tokens[i] = token + " "
        # Strip WordPiece continuation markers and unknown-token placeholders.
        text = "".join(tokens).replace("##", "").replace("[UNK]", "").strip()
        return text
    bot = gr.Chatbot(show_copy_button=True, show_share_button=True)

    with gr.Blocks() as demo:
        gr.Markdown("GPT2 chatbot | Powered by nlp-greyfoss")
        with gr.Accordion("Parameters in generation", open=False):
            with gr.Row():
                top_k = gr.Slider(
                    2.0,
                    100.0,
                    label="top_k",
                    step=1,
                    value=50,
                    info="Limits the number of candidate tokens considered at each decoding step.",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    label="top_p",
                    value=0.9,
                    info="Controls diversity by sampling only from tokens whose cumulative probability stays within the top-p threshold.",
                )
                temperature = gr.Slider(
                    0.1,
                    2.0,
                    label="temperature",
                    value=0.9,
                    info="Controls the randomness of the generated text: higher values produce more diverse, unpredictable output; lower values produce more conservative, coherent text.",
                )
                repetition_penalty = gr.Slider(
                    0.1,
                    2.0,
                    label="repetition_penalty",
                    value=1.2,
                    info="Discourages the model from repeating tokens it has already generated.",
                )
                no_repeat_ngram_size = gr.Slider(
                    0,
                    100,
                    label="no_repeat_ngram_size",
                    step=1,
                    value=5,
                    info="Prevents the model from generating any n-gram of this size that already appears in the context.",
                )
        gr.ChatInterface(
            response,
            chatbot=bot,
            additional_inputs=[
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            stop_btn="🛑 Stop",
            retry_btn="🔄 Regenerate",
            undo_btn="↩️ Remove last turn",
            clear_btn="➕ New conversation",
            # One example per row; the original nested all three messages in a
            # single list, which Gradio would read as one example row whose
            # extra entries overwrite the additional inputs.
            examples=[
                ["帮我生成一句英文，描述春天的美好。"],  # "Write an English sentence describing the beauty of spring."
                ["推荐一些好看的电影"],  # "Recommend some good movies."
                ["Write a poem for me."],
            ],
        )
    demo.queue().launch()
tokenizer = AutoTokenizer.from_pretrained("greyfoss/gpt2-chatbot-chinese")
# export=True converts the PyTorch checkpoint to ONNX on the fly so that
# inference runs through onnxruntime, which is typically faster on CPU.
model = ORTModelForCausalLM.from_pretrained("greyfoss/gpt2-chatbot-chinese", export=True)

gradio(model, tokenizer)
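# When hosted as a Hugging Face Space (as the original was), this script is
# typically saved as app.py and launched automatically on Space startup.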