BLOOMZ_Compare / app.py
gururise's picture
update gradio interface for iterative outputs
26fd787
raw
history blame
6.98 kB
import gradio as gr
import threading
import codecs
from datetime import datetime
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import time
# Use the GPU when torch can see one; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dtype passed to `from_pretrained` when a model is loaded.
TORCH_DTYPE = torch.bfloat16
# Position matters: the UI's CheckboxGroup (type="index") selects models by
# index, so 0 is BLOOM and 1 is BLOOMZ.
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]
# Lazy cache: model name -> (tokenizer, model) once loaded, None before that.
models = dict.fromkeys(MODEL_NAMES)
# Most recent generated text per model name; written by worker threads and
# read back when streaming results to the UI.
output = dict.fromkeys(MODEL_NAMES, "")
# Setting this event asks running generation threads to stop early.
kill = threading.Event()
def stop_threads():
    """Ask all running generation threads to finish early.

    This only sets the shared ``kill`` event, which the workers check between
    tokens. It does not wait for the threads to exit. No ``global`` statement
    is needed because the event is mutated, never rebound.
    """
    print("Force stopping threads")
    kill.set()
def _decode_stop_words(stop):
    """Expand backslash escapes in user-typed stop strings and drop empties.

    The UI lets users type escapes such as ``\\n`` as two characters; these are
    expanded with the ``unicode_escape`` codec. Each word is first encoded as
    latin-1 with ``backslashreplace`` so non-ASCII characters round-trip
    unchanged. Feeding the ``str`` straight to the codec decodes its UTF-8 bytes
    as latin-1 and turns e.g. "é" into "Ã©", so such a stop word could never
    match. A word with a malformed escape (e.g. a trailing backslash) is kept
    verbatim instead of raising inside the worker thread. Empty words are
    dropped because "" is contained in every output.

    Args:
        stop: iterable of raw stop strings.

    Returns:
        List of non-empty stop strings to search for in the generated text.
    """
    words = []
    for raw in stop:
        try:
            word = raw.encode("latin-1", "backslashreplace").decode("unicode_escape")
        except UnicodeDecodeError:
            word = raw
        if word:
            words.append(word)
    return words


def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
    """Generate text for one model token by token, publishing it in ``output``.

    Intended to run in a worker thread. Each step calls ``generate`` with
    ``max_new_tokens=1`` inside a single inference session and appends the
    decoded token to ``output[model_name]``, so another thread can poll the
    partial text. Generation ends when ``kill`` is set, when any stop word
    appears in the generated text, or after ``max_tokens`` steps.

    Args:
        model_name: key into ``models`` / ``output``; ``models[model_name]``
            must already hold a ``(tokenizer, model)`` pair.
        prompt: prompt text to tokenize.
        max_tokens: maximum number of generation steps.
        temperature, top_p, repetition_penalty: sampling settings forwarded
            to ``generate``.
        stop: iterable of stop strings; backslash escapes such as ``\\n`` are
            expanded before matching.
    """
    if kill.is_set():
        return
    # Loop-invariant: decode the stop words once rather than on every token.
    stop_words = _decode_stop_words(stop)
    tokenizer, model = models[model_name]
    thread_id = threading.get_ident()
    # NOTE(review): max_length=512 is hard-coded; presumably prompt + generated
    # tokens must fit within it — confirm against the Petals session API.
    with model.inference_session(max_length=512) as sess:
        print(f"Thread Start -> {thread_id}")
        output[model_name] = ""
        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        token_cnt = 0
        done = False
        while not done and not kill.is_set():
            outputs = model.generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess,
            )
            output[model_name] += tokenizer.decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n[" + str(thread_id) + "]" + output[model_name], end="", flush=True)
            if any(word in output[model_name] for word in stop_words):
                print(f"\nDONE (stop) -> {thread_id}")
                done = True
            if token_cnt >= max_tokens:
                print(f"\nDONE (max tokens) -> {thread_id}")
                done = True
            # Prefix is passed only for the 1st token of the bot's response;
            # later steps rely on `sess` to carry the context (assumed Petals
            # behaviour). Slicing from 0 assumes `generate` then returns only
            # the newly generated ids — TODO confirm.
            inputs = None
            n_input_tokens = 0
        print(f"\nThread End -> {thread_id}")
def to_md(text):
    """Turn every newline in ``text`` into an HTML ``<br />`` line break."""
    lines = text.split("\n")
    return "<br />".join(lines)
# Module-level list for generation worker threads (starts empty).
threads = []
def _check_range(name, value, low, high):
    """Raise ValueError unless ``low <= value <= high``."""
    if not low <= value <= high:
        raise ValueError(f"{name} must be between {low} and {high}, got {value!r}")


def infer(
    prompt,
    model_idx=(0, 1),
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    """Run the selected models on ``prompt`` and stream their outputs.

    A generator used as a Gradio event handler. It loads any selected model not
    yet cached in ``models``, starts one ``gen_thread`` worker per selected model,
    and yields ``(BLOOM text, BLOOMZ text)`` pairs read from ``output`` while the
    workers run. It always yields once more after every worker has finished.

    Args:
        prompt: prompt text; an empty prompt is replaced by a single space.
        model_idx: indices into ``MODEL_NAMES`` (0 = BLOOM, 1 = BLOOMZ). An empty
            selection yields nothing.
        max_new_tokens: tokens to generate per model, 1..384.
        temperature: 0.0..1.0; 0.0 is replaced by 0.01.
        top_p: 0.0..1.0.
        repetition_penalty: 0.9..3.0.
        stop: comma-separated stop strings; spaces around each entry are stripped.
        num_completions: only range-checked (1..5); not otherwise used.
        seed: accepted for interface compatibility; not used.

    Raises:
        ValueError: if a numeric argument is outside its allowed range.
    """
    if not model_idx:
        return

    # UI values may arrive as floats/strings; normalise before range checks.
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    stop = [word.strip(" ") for word in stop.split(",")]

    # Validate before any (slow) model loading. Explicit raises rather than
    # `assert` so the checks survive `python -O`.
    _check_range("max_new_tokens", max_new_tokens, 1, 384)
    _check_range("num_completions", num_completions, 1, 5)
    _check_range("temperature", temperature, 0.0, 1.0)
    _check_range("top_p", top_p, 0.0, 1.0)
    _check_range("repetition_penalty", repetition_penalty, 0.9, 3.0)

    if temperature == 0.0:
        temperature = 0.01  # sampling requires a strictly positive temperature
    if prompt == "":
        prompt = " "

    kill.clear()
    print("Loading Models\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        if models[model_name] is None:
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            models[model_name] = tokenizer, model.to(DEVICE)
        output[model_name] = ""

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    # Per-call list: a shared module-level list would grow forever and make
    # each request also wait on threads started by other requests.
    workers = []
    for idx in model_idx:
        worker = threading.Thread(
            target=gen_thread,
            args=(MODEL_NAMES[idx], prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop),
        )
        workers.append(worker)
        worker.start()

    for worker in workers:
        while worker.is_alive():
            worker.join(timeout=0.2)
            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
    # A worker that finished while an earlier one was being joined never got
    # a yield of its own; push the final text for both models.
    yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
# Example rows for gr.Examples. Each row supplies exactly one value per input
# component, in order: prompt, selected models, max tokens, temperature,
# top-p, repetition penalty, stop strings (7 values).
examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''',
        ["BLOOM", "BLOOMZ"],
        3,
        0.2,
        1.0,
        1.0,
        "\\n,</s>",
    ],
    [
        # Natural Language Inference (few-shot entailment/contradiction)
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''',
        ["BLOOM", "BLOOMZ"],
        2,
        0.2,
        1.0,
        1.0,
        "\\n,</s>",
    ],
]
def clear_prompt():
    """Blank out the prompt box and both model output boxes."""
    blank = ""
    return blank, blank, blank
# Gradio UI: prompt and generation settings in the left column, one output box
# per model in the right column, then a row of Clear / Run / Stop buttons.
# NOTE(review): `gr.Box` and `queue(concurrency_count=...)` are Gradio 3.x APIs
# that were removed in Gradio 4 — pin gradio<4 or migrate.
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=17,label="Prompt",placeholder="Enter Prompt", interactive=True)
            with gr.Box():
                # type="index" hands the handler positions into MODEL_NAMES
                # (0 = BLOOM, 1 = BLOOMZ) instead of the label strings.
                chk_boxes = gr.CheckboxGroup(choices=["BLOOM","BLOOMZ"],value=["BLOOM","BLOOMZ"], type="index", label="Model")
                #min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length") #min_length
                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens") # max_tokens
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature") # temperature
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P") # top_p
                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty") # repetition penalty
                # Comma-separated stop strings; the default is typed with a
                # literal backslash ("\\n") that the worker expands to a newline.
                stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Token") # stop
        with gr.Column():
            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
            bloomz_out = gr.Textbox(lines=7,label="BLOOMZ OUTPUT:")
    with gr.Row():
        btn_clear = gr.Button("Clear", variant="secondary")
        btn_run = gr.Button("Run", variant="primary")
        btn_stop = gr.Button("Stop", variant="stop")
    # The input order must match infer's positional parameters (prompt,
    # model_idx, max_new_tokens, temperature, top_p, repetition_penalty, stop).
    # infer yields (BLOOM text, BLOOMZ text) pairs, mapped in order onto the two boxes.
    click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out,bloomz_out])
    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
    # Stop sets the shared kill event (so worker threads exit between tokens)
    # and cancels the pending Run event.
    btn_stop.click(stop_threads,cancels=click_run)
    # Each row of `examples` must supply one value per component listed here.
    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])
# Queueing is required for the generator handler to stream and for `cancels=`.
# NOTE(review): with concurrency_count=3, concurrent requests share the
# module-level `output` dict and `kill` event, so one user's Stop or run can
# affect another's — consider per-request state.
demo.queue(concurrency_count=3)
demo.launch()