from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
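
# Assisted generation pairs a large main model with a small "assistant" model: the
# assistant drafts candidate tokens that the main model then verifies in a single
# forward pass, speeding up decoding. The two models must share a tokenizer, which
# the Pythia family does.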
model_id = "EleutherAI/pythia-6.9b-deduped"
assistant_id = "EleutherAI/pythia-70m-deduped"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
if torch_device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
else:
    model = AutoModelForCausalLM.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# The 70M assistant is small enough to run in full precision on the same device.
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(torch_device)
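
# run_generation is a generator: each yield sends the partial text to Gradio, which
# re-renders the output textbox, producing a live streaming effect.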
def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
    # Tokenize the user text and move it to the model's device.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Start generation on a separate thread, so that we don't block the UI. The text is
    # pulled from the streamer in the main thread. The timeout on the streamer surfaces
    # exceptions raised in the generation thread instead of hanging forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
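    # Passing assistant_model below is what switches generate() into assisted
    # generation: the 70M draft model proposes tokens and the 6.9B model accepts or
    # rejects them, cutting latency without changing the sampled distribution.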
    generate_kwargs = dict(
        model_inputs,  # BatchEncoding is dict-like, so this merges in input_ids/attention_mask
        streamer=streamer,
        assistant_model=assistant_model,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=float(temperature),
        top_k=top_k,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output

def reset_textbox():
    return gr.update(value='')

with gr.Blocks() as demo:
    gr.Markdown(
        "# 🤗 Assisted Generation Demo\n"
        f"Model: {model_id}\n"
        f"Assistant Model: {assistant_id}"
    )
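    # Layout: user input and model output on the left, generation parameters on the right.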
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="Write an email about an alpaca that likes flan",
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature",
            )
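    # Both pressing Enter in the input box and clicking Submit stream
    # run_generation's yields into the output textbox.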
    user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
    button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
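
# Queue requests so concurrent users are served in order; max_size=32 rejects new
# requests once 32 are already waiting, which keeps a shared demo responsive.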
demo.queue(max_size=32).launch()