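# Assumed runtime dependencies for this Space (not pinned in the source):
#   gradio, transformers, optimum[onnxruntime], torch, regex
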
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch  # needed at runtime for return_tensors="pt"
import gradio as gr
import itertools
import regex as re  # third-party regex module, needed for \p{Han}
import logging
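
# Special tokens that delimit conversation turns in the prompt built below.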
user_token = "<User>"
eos_token = "<EOS>"
bos_token = "<BOS>"
bot_token = "<Assistant>"

logger = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter(
    '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Keep at most this many prompt tokens; older context is truncated from the left.
max_context_length = 750
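

# Build the model prompt from the flattened turn history:
#   <BOS><User>msg1<EOS><Assistant>reply1<User>msg2<EOS>...<Assistant>
# The trailing <Assistant> token cues the model to generate the next reply.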
def format(history):
    prompt = bos_token
    for idx, txt in enumerate(history):
        if idx % 2 == 0:  # even indices are user messages
            prompt += f"{user_token}{txt}{eos_token}"
        else:  # odd indices are assistant replies
            prompt += f"{bot_token}{txt}"
    prompt += bot_token
    logger.info(prompt)
    return prompt
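

# The decoded output can contain spaces between Chinese characters; strip
# those while leaving single spaces between English words intact.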
def remove_spaces_between_chinese(text):
    rex = r"(?<![a-zA-Z]{2})(?<=[a-zA-Z]{1})[ ]+(?=[a-zA-Z] |.$)|(?<=\p{Han}) +"
    return re.sub(rex, "", text, 0, re.MULTILINE | re.UNICODE)
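

# Build the Gradio UI: a Chatbot plus sliders for the generation parameters,
# all wired to response() through gr.ChatInterface.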
def gradio(model, tokenizer):
    def response(
        user_input,
        chat_history,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        no_repeat_ngram_size,
    ):
        # Flatten [[user, bot], ...] turn pairs into a single list and
        # append the new user message.
        history = list(itertools.chain(*chat_history))
        history.append(user_input)
        prompt = format(history)

        # Encode the prompt and keep only the most recent tokens.
        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )[:, -max_context_length:]
        prompt_length = input_ids.shape[1]

        beam_output = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=250,
            num_beams=1,  # a single beam keeps CPU latency low
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            early_stopping=True,
            do_sample=True,
        )
        # Drop the prompt tokens and decode only the newly generated reply.
        output = beam_output[0][prompt_length:]
        generated = remove_spaces_between_chinese(
            tokenizer.decode(
                output,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
        return generated

    bot = gr.Chatbot(show_copy_button=True, show_share_button=True, height=2000)

    with gr.Blocks() as demo:
        gr.Markdown("GPT2 chatbot | Powered by nlp-greyfoss")
        with gr.Accordion("Parameters in generation", open=False):
            with gr.Row():
                top_k = gr.Slider(
                    2.0,
                    100.0,
                    label="top_k",
                    step=1,
                    value=50,
                    info="Limit the number of candidate tokens considered during decoding.",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    label="top_p",
                    value=0.9,
                    info="Control the diversity of the output by sampling only from tokens whose cumulative probability is within the top_p threshold.",
                )
                temperature = gr.Slider(
                    0.1,
                    2.0,
                    label="temperature",
                    value=0.9,
                    info="Control the randomness of the generated text: higher values give more diverse and unpredictable outputs, lower values give more conservative and coherent text.",
                )
                repetition_penalty = gr.Slider(
                    0.1,
                    2.0,
                    label="repetition_penalty",
                    value=1.2,
                    info="Discourage the model from generating repetitive tokens in a sequence.",
                )
                no_repeat_ngram_size = gr.Slider(
                    0,
                    100,
                    label="no_repeat_ngram_size",
                    step=1,
                    value=5,
                    info="Prevent the model from repeating any n-gram of this size that already appears in the context.",
                )
        gr.ChatInterface(
            response,
            chatbot=bot,
            additional_inputs=[
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            retry_btn="🔄 Regenerate",
            undo_btn="↩️ Remove last turn",
            clear_btn="➕ New conversation",
            examples=[
                # [message, top_k, top_p, temperature, repetition_penalty, no_repeat_ngram_size]
                ["写一篇介绍人工智能的文章。", 30, 0.9, 0.95, 1.2, 5],  # "Write an article introducing artificial intelligence."
                ["给我讲一个笑话。", 50, 0.8, 0.9, 1.2, 6],  # "Tell me a joke."
                ["Can you describe spring in English?", 50, 0.9, 1.0, 1, 5],
            ],
        )

    demo.queue().launch()
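

# Load the tokenizer and export the PyTorch checkpoint to ONNX (export=True)
# so generation runs through ONNX Runtime, suited to CPU-only Spaces hardware.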
tokenizer = AutoTokenizer.from_pretrained("greyfoss/gpt2-chatbot-chinese")
model = ORTModelForCausalLM.from_pretrained("greyfoss/gpt2-chatbot-chinese", export=True)

gradio(model, tokenizer)