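# gpt2-chatbot / app.py
# Repository file-viewer header captured together with the source (not part of
# the program): last commit 88c96b4 (verified) by greyfoss, "Update app.py";
# file size 4.46 kB.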
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from collections import defaultdict
import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM
import itertools
import re
# Marker strings that format() stitches around chat turns to build the prompt.
# Prefix placed before every user message.
user_token = "<User>"
# Terminator appended after every user message.
eos_token = "<EOS>"
# Marker that opens every prompt.
bos_token = "<BOS>"
# Prefix placed before every assistant message, and the final cue that ends the
# prompt so the model continues with the next assistant reply.
bot_token = "<Assistant>"
# Compiled once at import time instead of on every call.
_ASCII_LETTERS_RE = re.compile(r"[a-zA-Z]+")


def is_english_word(tested_string):
    """Return True when *tested_string* consists solely of ASCII letters.

    The whole string must match: the empty string, digits, punctuation,
    non-ASCII characters and word-piece continuations such as ``"##ing"`` all
    yield False. ``fullmatch`` is used rather than ``^...$`` with ``match``
    because ``$`` also matches just before a trailing newline, which would
    wrongly accept ``"abc\\n"``.

    Args:
        tested_string: The token text to classify.

    Returns:
        bool: True if every character is in ``a-z`` or ``A-Z`` and the string
        is non-empty, otherwise False.
    """
    return _ASCII_LETTERS_RE.fullmatch(tested_string) is not None
def format(history):
    """Assemble the model prompt from a flat list of alternating chat turns.

    Entries at even positions of *history* are wrapped as user turns
    (``user_token`` + text + ``eos_token``); entries at odd positions are
    wrapped as assistant turns (``bot_token`` + text, with no terminator
    added here). The prompt opens with ``bos_token`` and always ends with a
    bare ``bot_token``.

    Args:
        history: Sequence of message strings, user message first.

    Returns:
        str: The assembled prompt. It is also printed to stdout.
    """
    segments = [bos_token]
    for position, message in enumerate(history):
        spoken_by_user = position % 2 == 0
        if spoken_by_user:
            segments.append(f"{user_token}{message}{eos_token}")
        else:
            segments.append(f"{bot_token}{message}")
    # The trailing assistant marker cues the model to write the next reply.
    segments.append(bot_token)
    prompt = "".join(segments)
    print(prompt)
    return prompt
def _tokens_to_text(tokens):
    """Join tokenizer tokens into display text.

    A single space is inserted between two adjacent tokens when both are plain
    ASCII-letter words (per ``is_english_word``); every other pair is joined
    without a separator. ``"##"`` word-piece markers and ``"<UNK>"`` are then
    removed and surrounding whitespace is stripped.
    """
    pieces = []
    for current, following in zip(tokens, tokens[1:]):
        needs_space = is_english_word(current) and is_english_word(following)
        pieces.append(current + " " if needs_space else current)
    # zip() stops one short of the end; the final token never gets a space.
    pieces.extend(tokens[-1:])
    # NOTE(review): only "<UNK>" is removed; any other special token the model
    # emits stays in the text. format() adds no end marker after assistant
    # turns, so this may be relied upon -- confirm before stripping more.
    return "".join(pieces).replace("##", "").replace("<UNK>", "").strip()


def gradio(model, tokenizer):
    """Build the chat UI around *model* and *tokenizer*, then launch it.

    Args:
        model: Object whose ``generate(input_ids, **kwargs)`` returns a batch
            of token-id sequences that begin with the prompt ids.
        tokenizer: Object providing ``encode``, ``convert_ids_to_tokens`` and
            ``pad_token_id``.

    Returns:
        None. The function ends by calling ``demo.queue().launch()``.
    """

    def response(
        user_input,
        chat_history,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        no_repeat_ngram_size,
    ):
        """Generate the assistant reply to *user_input*.

        Args:
            user_input: The newest user message.
            chat_history: Iterable of earlier exchanges; each element is itself
                an iterable of messages and all are flattened in order
                (presumably ``[user, assistant]`` pairs supplied by
                ``gr.ChatInterface`` -- confirm against the Gradio version).
            top_k, top_p, temperature, repetition_penalty,
            no_repeat_ngram_size: Current slider values, forwarded to
                ``model.generate``.

        Returns:
            str: The generated reply text.
        """
        history = list(itertools.chain.from_iterable(chat_history))
        history.append(user_input)
        prompt = format(history)

        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            # The prompt already carries its own marker tokens.
            add_special_tokens=False,
        )
        prompt_length = input_ids.shape[1]

        generated = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=255,
            # Sampling has to be switched on explicitly: without it generate()
            # decodes greedily and top_k / top_p / temperature are ignored, so
            # the corresponding sliders would have no effect.
            do_sample=True,
            # Integer-only knobs; the slider bounds are floats, so coerce.
            top_k=int(top_k),
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=int(no_repeat_ngram_size),
            # early_stopping is a beam-search option and is not passed, since
            # no beam search is requested here.
        )

        # Keep only the newly generated ids; the output starts with the prompt.
        reply_ids = generated[0][prompt_length:]
        return _tokens_to_text(tokenizer.convert_ids_to_tokens(reply_ids))

    # Created before the Blocks context and handed to ChatInterface below.
    bot = gr.Chatbot(scale=8)

    with gr.Blocks() as demo:
        gr.Markdown("GPT2 chatbot | Powered by nlp-greyfoss")
        with gr.Accordion("Parameters in generation", open=False):
            with gr.Row():
                top_k = gr.Slider(
                    2.0,
                    100.0,
                    label="top_k",
                    step=1,
                    value=50,
                    info="Limit the number of candidate tokens considered during decoding.",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    label="top_p",
                    value=0.9,
                    info="Control the diversity of the output by selecting tokens with cumulative probabilities up to the Top-P threshold.",
                )
                temperature = gr.Slider(
                    0.1,
                    2.0,
                    label="temperature",
                    value=0.9,
                    info="Control the randomness of the generated text. A higher temperature results in more diverse and unpredictable outputs, while a lower temperature produces more conservative and coherent text.",
                )
                repetition_penalty = gr.Slider(
                    0.1,
                    2.0,
                    label="repetition_penalty",
                    value=1.2,
                    info="Discourage the model from generating repetitive tokens in a sequence.",
                )
                no_repeat_ngram_size = gr.Slider(
                    0,
                    100,
                    label="no_repeat_ngram_size",
                    step=1,
                    value=5,
                    info="Prevent the model from generating sequences of n consecutive tokens that have already been generated in the context. ",
                )
        # The list order below mirrors the extra parameters of response().
        gr.ChatInterface(
            response,
            chatbot=bot,
            fill_vertical_space=True,
            additional_inputs=[
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
        )

    demo.queue().launch()
# Tokenizer and model are loaded from the same checkpoint identifier.
model_id = "greyfoss/gpt2-chatbot-chinese"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): export=True presumably makes Optimum convert the checkpoint to
# ONNX while loading -- confirm against the Optimum documentation.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
# Builds the UI and launches it.
gradio(model, tokenizer)