Commit 5ea073d by Luigi (parent: ef361b0)

make streaming response

Files changed (1): app.py (+36, -25)
app.py CHANGED
@@ -6,7 +6,7 @@ from itertools import islice
 from datetime import datetime
 import gradio as gr
 import torch
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from duckduckgo_search import DDGS
 import spaces  # Import spaces early to enable ZeroGPU support

@@ -134,19 +134,19 @@ def format_conversation(conversation, system_prompt):
     return prompt

 # ------------------------------
-# Chat Response Generation with ZeroGPU using Pipeline
+# Chat Response Generation with ZeroGPU using Pipeline (Streaming Token-by-Token)
 # ------------------------------
 @spaces.GPU(duration=60)
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
     """
-    Generate a chat response by utilizing a transformers pipeline.
+    Generate a chat response by utilizing a transformers pipeline with streaming.

     - Appends the user's message to the conversation history.
     - Optionally retrieves web search context and inserts it as an additional system message.
     - Converts the conversation into a formatted prompt to avoid leaking role labels.
-    - Uses the cached pipeline to generate a response.
-    - Returns the updated conversation history and a debug message.
+    - Uses the cached pipeline's underlying model and tokenizer with a streamer to yield tokens as they are generated.
+    - Yields the updated conversation history token by token.
     """
     cancel_event.clear()

@@ -179,32 +179,43 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
     conversation.append({"role": "assistant", "content": ""})

     try:
-        # Format the entire conversation into a single prompt (this fixes both issues).
+        # Format the entire conversation into a single prompt.
         prompt_text = format_conversation(conversation, system_prompt)

-        # Load the pipeline (cached) for the selected model.
+        # Load the pipeline.
         pipe = load_pipeline(model_name)
+        # Obtain the underlying tokenizer and model.
+        tokenizer = pipe.tokenizer
+        model = pipe.model

-        # Generate a response using the formatted prompt.
-        response = pipe(
-            prompt_text,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repeat_penalty,
-        )
+        # Tokenize the formatted prompt.
+        model_inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

-        # Extract the generated text.
-        generated = response[0]["generated_text"]
-        # Remove the prompt portion so we only keep the new assistant reply.
-        assistant_text = generated[len(prompt_text):].strip()
+        # Set up a streamer for token-by-token generation.
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

-        # Update the conversation history.
-        conversation[-1]["content"] = assistant_text
+        # Run generate in a background thread with the streamer.
+        gen_kwargs = {
+            "input_ids": model_inputs.input_ids,
+            "attention_mask": model_inputs.attention_mask,
+            "max_new_tokens": max_tokens,
+            "temperature": temperature,
+            "top_k": top_k,
+            "top_p": top_p,
+            "repetition_penalty": repeat_penalty,
+            "streamer": streamer,
+        }
+        thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
+        thread.start()

-        # Yield the complete conversation history and the debug message.
-        yield conversation, debug_message
+        # Collect tokens from the streamer as they are generated.
+        assistant_text = ""
+        for new_text in streamer:
+            assistant_text += new_text
+            conversation[-1]["content"] = assistant_text
+            yield conversation, debug_message  # Update the UI token by token.
+
+        thread.join()
     except Exception as e:
         conversation[-1]["content"] = f"Error: {e}"
         yield conversation, debug_message
@@ -288,7 +299,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
     cancel_button.click(fn=cancel_generation, outputs=search_debug)

-    # Submission: the chat_response function is used with the Transformers pipeline.
+    # Submission: the chat_response function is used with streaming.
     msg_input.submit(
         fn=chat_response,
         inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
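For reference, the mechanism this commit adopts is the standard Transformers streaming recipe: model.generate() blocks until the whole sequence is finished, so it runs on a background thread while the caller drains a TextIteratorStreamer. Note also how skip_prompt=True replaces the old generated[len(prompt_text):] slicing. A minimal self-contained sketch of just this pattern, assuming any small causal LM (the model name and prompt below are placeholders, not values from this Space):

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "sshleifer/tiny-gpt2"  # placeholder model, chosen only for its size
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# skip_prompt=True drops the echoed prompt; skip_special_tokens=True drops EOS markers.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() only returns once generation is finished, so run it off the main thread.
thread = threading.Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 20, "streamer": streamer},
)
thread.start()

for chunk in streamer:  # blocks until the next decoded fragment is ready
    print(chunk, end="", flush=True)
thread.join()

Gradio treats a generator event handler as a stream, so each yield conversation, debug_message in the diff repaints the Chatbot with the partially built assistant message.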
 
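One loose end a reviewer might flag: the handler still calls cancel_event.clear() on entry, and the Cancel button is wired to cancel_generation, but the new streaming loop never checks the event, so cancelling cannot interrupt the background generate() call. If cooperative cancellation is wanted, a custom StoppingCriteria tied to the existing event is one option. A sketch under that assumption, not part of this commit (CancelOnEvent is a hypothetical name):

import threading

from transformers import StoppingCriteria, StoppingCriteriaList

class CancelOnEvent(StoppingCriteria):
    """Hypothetical helper: abort generate() once a shared event is set."""

    def __init__(self, event: threading.Event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Returning True stops generation after the current decoding step.
        return self.event.is_set()

# Wiring it into the commit's gen_kwargs would look roughly like this:
#   gen_kwargs["stopping_criteria"] = StoppingCriteriaList([CancelOnEvent(cancel_event)])
# and the streaming loop could bail out early as well:
#   for new_text in streamer:
#       if cancel_event.is_set():
#           break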