# Page header captured together with the file (not part of the program):
# mgoin's picture | Update app.py | ce89e91 | raw | history blame | 4.65 kB
import deepsparse
from transformers import TextIteratorStreamer
from threading import Thread
import time
import gradio as gr
from typing import Tuple, List
# Startup diagnostic: have DeepSparse print its hardware-capability report
# once, when the module is loaded.
deepsparse.cpu.print_hardware_capability()
# Model stub handed to the DeepSparse pipeline and echoed in the page header.
MODEL_PATH = "TinyStories-1M"

# Markdown shown at the top of the demo page.
DESCRIPTION = (
    "\n"
    "# TinyStories running on DeepSparse\n"
    f"The model stub for this example is: {MODEL_PATH}\n"
)

# Token budget: the slider's initial value and its upper bound.
DEFAULT_MAX_NEW_TOKENS = 512
MAX_MAX_NEW_TOKENS = 2048
def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Empty the textbox while keeping what the user typed.

    Args:
        message: Current textbox content.

    Returns:
        A pair ``("", message)``: the first item is the new (blank) textbox
        value, the second is the text to stash for the next event step.
    """
    return "", message
def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Append a new turn with an empty reply to the chat history.

    Args:
        message: The user's message for the new turn.
        history: Existing ``(user, bot)`` turns. This list is modified in
            place.

    Returns:
        The same ``history`` list, now ending with ``(message, "")``.
    """
    # The reply starts out empty; it is filled in later by the generation step.
    history.append((message, ""))
    return history
def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Remove the most recent turn and hand back its user message.

    Args:
        history: Existing ``(user, bot)`` turns. The last turn, if any, is
            popped from this list in place.

    Returns:
        ``(history, message)`` where ``message`` is the user text of the
        removed turn, or ``""`` when there was no turn to remove or the
        stored message was falsy (e.g. ``None``).
    """
    try:
        message, _ = history.pop()
    except IndexError:
        # Nothing to undo: an empty history stays empty.
        message = ""
    return history, message or ""
# Setup the engine: one text-generation pipeline is built when the module is
# loaded and shared by every request handler below.
_engine_options = {
    "task": "text-generation",
    "model_path": MODEL_PATH,
    # Initial generation cap; `generate` overwrites it per request.
    "max_generated_tokens": DEFAULT_MAX_NEW_TOKENS,
    # The sequence length is set to the slider's upper bound.
    "sequence_length": MAX_MAX_NEW_TOKENS,
}
pipe = deepsparse.Pipeline.create(**_engine_options)
with gr.Blocks(css="style.css") as demo:
    # NOTE(review): the container nesting below (Group > Row for the input
    # line, a separate Row for the action buttons) is reconstructed from
    # statement order - confirm it matches the intended layout.
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label="Chatbot")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Type a message...",
                scale=10,
            )
            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)

    with gr.Row():
        retry_button = gr.Button("🔄 Retry", variant="secondary")
        undo_button = gr.Button("↩️ Undo", variant="secondary")
        clear_button = gr.Button("🗑️ Clear", variant="secondary")

    # Carries the submitted message from one chained event step to the next,
    # because the textbox itself is cleared in the first step.
    saved_input = gr.State()

    gr.Examples(
        examples=["Once upon a time"],
        inputs=[textbox],
    )

    max_new_tokens = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )
    temperature = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=4.0,
        step=0.1,
        value=1.0,
    )

    # Generation inference
    def generate(message, history, max_new_tokens: int, temperature: float):
        """Stream the model's continuation of ``message`` into the last turn.

        Args:
            message: Prompt text passed to the pipeline as ``sequences``.
            history: Chat turns; the last one is the turn being answered and
                must already exist (``display_input`` appends it).
            max_new_tokens: Written to ``pipe.max_generated_tokens``.
            temperature: Written to ``pipe.sampling_temperature``.

        Yields:
            ``history`` after each streamed chunk, with the chunk appended to
            the reply of the last turn.
        """
        streamer = TextIteratorStreamer(pipe.tokenizer)

        # These are set on the shared module-level pipeline object.
        pipe.max_generated_tokens = max_new_tokens
        pipe.sampling_temperature = temperature

        # Run the pipeline on a worker thread so this generator can consume
        # the streamer while generation is still in progress.
        generation_kwargs = dict(sequences=message, streamer=streamer)
        thread = Thread(target=pipe, kwargs=generation_kwargs)
        thread.start()

        # Replace the last turn instead of mutating it in place, so this also
        # works when the entry is an immutable tuple, which is what
        # `display_input` appends.
        prompt, reply = history[-1]
        for new_text in streamer:
            reply += new_text
            history[-1] = [prompt, reply]
            yield history

        thread.join()
        print(pipe.timer_manager)

    # Hooking up all the buttons
    # Pressing Enter in the textbox and clicking "Submit" run the same chain:
    # stash and clear the input, show the new turn, then generate the reply.
    for submit_trigger in (textbox.submit, submit_button.click):
        submit_trigger(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).success(
            generate,
            inputs=[saved_input, chatbot, max_new_tokens, temperature],
            outputs=[chatbot],
            api_name=False,
        )

    # Retry: drop the last turn, re-add its message, and generate again.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[saved_input, chatbot, max_new_tokens, temperature],
        outputs=[chatbot],
        api_name=False,
    )

    # Undo: drop the last turn and put its message back into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset both the visible conversation and the stashed message.
    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )
# Turn on Gradio's request queue, then start serving. This runs as soon as the
# module is loaded.
_queued_demo = demo.queue()
_queued_demo.launch()