# Hugging Face Spaces page header captured with this file:
# Spaces: Sleeping
# TinyStories text-generation chat demo: a DeepSparse pipeline behind a
# Gradio chat UI that streams generated text.
from threading import Thread
import time  # NOTE(review): not used anywhere in this file as shown.
from typing import List, Tuple

import deepsparse
import gradio as gr
from transformers import TextIteratorStreamer

# Print DeepSparse's report of this machine's hardware capabilities at startup.
deepsparse.cpu.print_hardware_capability()

# Model passed to deepsparse.Pipeline.create(model_path=...); presumably a
# local model directory shipped with the Space or a model stub — confirm.
MODEL_PATH = "TinyStories-1M"
# Markdown header rendered at the top of the UI.
DESCRIPTION = f"""
# TinyStories - DeepSparse
The model stub for this example is: {MODEL_PATH}
"""
# Maximum of the "Max new tokens" slider; also passed to the pipeline as
# its sequence_length.
MAX_MAX_NEW_TOKENS = 2048
# Initial slider value and the pipeline's initial max_generated_tokens.
DEFAULT_MAX_NEW_TOKENS = 128
def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Empty the textbox while keeping the submitted text.

    Args:
        message: Current contents of the textbox.

    Returns:
        ``("", message)``: the new (empty) textbox value and the message to
        store for the following chained event steps.
    """
    return "", message
def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Append the user's message to the chat with an empty bot reply.

    The empty reply is a placeholder that ``generate`` later fills in as
    text streams back from the model.

    Args:
        message: The user's submitted message.
        history: Chat history as ``(user_message, bot_message)`` pairs. It is
            modified in place; ``None`` is treated as an empty history.

    Returns:
        The history with ``(message, "")`` appended.
    """
    if history is None:
        history = []
    history.append((message, ""))
    return history
def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Remove the most recent exchange from the chat history.

    Args:
        history: Chat history as ``(user_message, bot_message)`` pairs. It is
            modified in place; ``None`` is treated as an empty history.

    Returns:
        The (possibly shortened) history, and the user message of the removed
        pair. The message is ``""`` when there was nothing to remove or when
        the removed message was empty/``None``.
    """
    if history is None:
        history = []
    try:
        message, _ = history.pop()
    except IndexError:
        # Nothing to undo/retry: report an empty message.
        message = ""
    return history, message or ""
# Set up the DeepSparse text-generation engine once at import time. This
# single pipeline object is shared by every request; generate() overwrites
# its max_generated_tokens / sampling_temperature before each call.
pipe = deepsparse.Pipeline.create(
    task="text-generation",
    model_path=MODEL_PATH,
    max_generated_tokens=DEFAULT_MAX_NEW_TOKENS,
    # NOTE(review): sequence_length presumably bounds prompt + generated
    # tokens, while the UI slider allows up to MAX_MAX_NEW_TOKENS new tokens
    # on its own — confirm how deepsparse handles the overflow.
    sequence_length=MAX_MAX_NEW_TOKENS,
)
# Generation inference
def generate(message, history, max_new_tokens: int, temperature: float):
    """Stream a completion for ``message`` into the last chat entry.

    The shared DeepSparse pipeline runs on a background thread with a
    ``TextIteratorStreamer``. This generator drains the streamer and yields
    the updated history after every chunk, so Gradio can re-render the
    chatbot incrementally.

    Args:
        message: Prompt text passed to the pipeline as ``sequences``.
        history: Chat history whose last entry is the ``(user, bot)`` pair
            being filled in (``display_input`` appends it with an empty bot
            reply before this runs).
        max_new_tokens: Written to ``pipe.max_generated_tokens`` before the
            call.
        temperature: Written to ``pipe.sampling_temperature`` before the call.

    Yields:
        ``history`` after each streamed chunk has been appended to the last
        entry's bot reply.

    Raises:
        Any exception raised by the pipeline, re-raised here after the worker
        thread has finished. Previously such an exception was lost and this
        generator blocked forever.
    """
    streamer = TextIteratorStreamer(pipe.tokenizer)
    # NOTE(review): these settings live on the single shared pipeline, so
    # overlapping requests could observe each other's values — confirm the
    # Gradio queue serializes generation before relying on them per request.
    pipe.max_generated_tokens = max_new_tokens
    pipe.sampling_temperature = temperature

    errors = []

    def _run_pipeline():
        """Run the pipeline on the worker thread, recording any failure."""
        try:
            pipe(sequences=message, streamer=streamer)
        except Exception as err:  # re-raised in the caller after join()
            errors.append(err)
            # A failed run may never signal end-of-stream; without this the
            # consumer loop below would wait on the streamer forever.
            streamer.end()

    thread = Thread(target=_run_pipeline)
    thread.start()
    for new_text in streamer:
        # Rebuild the last entry rather than mutating it, so tuple entries
        # and a None bot reply are handled too.
        user_message, bot_message = history[-1]
        history[-1] = [user_message, (bot_message or "") + new_text]
        yield history
    thread.join()
    if errors:
        raise errors[0]
    print(pipe.timer_manager)


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        chatbot = gr.Chatbot(label="Chatbot")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Type a message...",
                scale=10,
            )
            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
    with gr.Row():
        retry_button = gr.Button("🔄 Retry", variant="secondary")
        undo_button = gr.Button("↩️ Undo", variant="secondary")
        clear_button = gr.Button("🗑️ Clear", variant="secondary")
    # Holds the most recently submitted message between chained event steps.
    saved_input = gr.State()
    gr.Examples(
        examples=["Once upon a time"],
        inputs=[textbox],
    )
    max_new_tokens = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )
    temperature = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=4.0,
        step=0.1,
        value=1.0,
    )

    # Hooking up all the buttons.
    # Pressing Enter in the textbox and clicking Submit run the same chain:
    # stash and clear the input, echo it into the chat, then stream a reply.
    for trigger in (textbox.submit, submit_button.click):
        trigger(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).success(
            generate,
            inputs=[saved_input, chatbot, max_new_tokens, temperature],
            outputs=[chatbot],
            api_name=False,
        )

    # Retry: drop the last exchange, re-post its user message, regenerate.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[saved_input, chatbot, max_new_tokens, temperature],
        outputs=[chatbot],
        api_name=False,
    )

    # Undo: drop the last exchange and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: empty the chat and the saved message.
    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )
# Enable request queueing (used by the streaming generate() handler) and
# start the web server.
demo.queue().launch()