Spaces: Running on Zero
make streaming response
app.py CHANGED
@@ -6,7 +6,7 @@ from itertools import islice
 from datetime import datetime
 import gradio as gr
 import torch
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from duckduckgo_search import DDGS
 import spaces  # Import spaces early to enable ZeroGPU support
 
@@ -134,19 +134,19 @@ def format_conversation(conversation, system_prompt):
     return prompt
 
 # ------------------------------
-# Chat Response Generation with ZeroGPU using Pipeline
+# Chat Response Generation with ZeroGPU using Pipeline (Streaming Token-by-Token)
 # ------------------------------
 @spaces.GPU(duration=60)
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
     """
-    Generate a chat response by utilizing a transformers pipeline.
+    Generate a chat response by utilizing a transformers pipeline with streaming.
 
     - Appends the user's message to the conversation history.
    - Optionally retrieves web search context and inserts it as an additional system message.
     - Converts the conversation into a formatted prompt to avoid leaking role labels.
-    - Uses the cached pipeline
-
+    - Uses the cached pipeline’s underlying model and tokenizer with a streamer to yield tokens as they are generated.
+    - Yields updated conversation history token by token.
     """
     cancel_event.clear()
 
@@ -179,32 +179,43 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
     conversation.append({"role": "assistant", "content": ""})
 
     try:
-        # Format the entire conversation into a single prompt
+        # Format the entire conversation into a single prompt.
         prompt_text = format_conversation(conversation, system_prompt)
 
-        # Load the pipeline
+        # Load the pipeline.
         pipe = load_pipeline(model_name)
+        # Obtain the underlying tokenizer and model.
+        tokenizer = pipe.tokenizer
+        model = pipe.model
 
-        # Generate the assistant response with the pipeline.
-        outputs = pipe(
-            prompt_text,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repeat_penalty,
-        )
+        # Tokenize the formatted prompt.
+        model_inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
 
-        # Extract the generated text from the pipeline output.
-        generated = outputs[0]["generated_text"]
-        # Remove the prompt portion so we only keep the new assistant reply.
-        assistant_text = generated[len(prompt_text):].strip()
+        # Set up a streamer for token-by-token generation.
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        # Store the assistant reply in the conversation.
-        conversation[-1]["content"] = assistant_text
+        # Run generate in a background thread with the streamer.
+        gen_kwargs = {
+            "input_ids": model_inputs.input_ids,
+            "attention_mask": model_inputs.attention_mask,
+            "max_new_tokens": max_tokens,
+            "temperature": temperature,
+            "top_k": top_k,
+            "top_p": top_p,
+            "repetition_penalty": repeat_penalty,
+            "streamer": streamer
+        }
+        thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
+        thread.start()
 
-        # Yield the completed conversation.
-        yield conversation, debug_message
+        # Collect tokens from the streamer as they are generated.
+        assistant_text = ""
+        for new_text in streamer:
+            assistant_text += new_text
+            conversation[-1]["content"] = assistant_text
+            yield conversation, debug_message  # Update UI token by token
+
+        thread.join()
     except Exception as e:
         conversation[-1]["content"] = f"Error: {e}"
         yield conversation, debug_message
@@ -288,7 +299,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
     cancel_button.click(fn=cancel_generation, outputs=search_debug)
 
-    # Submission: the chat_response function is used with
+    # Submission: the chat_response function is used with streaming.
     msg_input.submit(
         fn=chat_response,
         inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
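The core of this change is the standard transformers streaming pattern: `model.generate` runs in a background thread while a `TextIteratorStreamer` is consumed on the main thread, yielding decoded text chunks as they arrive. The sketch below shows that pattern in isolation; it is a minimal, hedged example rather than the Space's actual code, and the model name and prompt are placeholders.

```python
# Minimal sketch of streaming generation with TextIteratorStreamer.
# The model name and prompt below are placeholders, not the Space's defaults.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "The quick brown fox"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# skip_prompt=True drops the echoed prompt; skip_special_tokens=True drops BOS/EOS markers.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs in a worker thread while the
# main thread consumes the streamer iterator.
thread = threading.Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 50, "do_sample": True, "streamer": streamer},
)
thread.start()

text = ""
for chunk in streamer:  # yields decoded text pieces as tokens arrive
    text += chunk
    print(chunk, end="", flush=True)

thread.join()
```

Because `generate()` only returns once generation is complete, the worker thread is what lets the loop over `streamer` start before the model has finished; `thread.join()` afterwards ensures the generation call has fully ended before continuing.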