Luigi committed
Commit ac8e9cc · 1 Parent(s): f248fec

use chat pipeline instead of model and tokenizer individually

Files changed (1)
  app.py  +72 -92
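
For orientation, this commit replaces manual tokenization plus model.generate() with a single Transformers text-generation pipeline call. A minimal sketch of that pattern, separate from app.py (the repo_id comes from the MODELS table in this diff; the prompt and sampling values are illustrative):

from transformers import pipeline

# One pipeline object handles tokenization, chat templating, and generation.
pipe = pipeline(
    "text-generation",
    model="DavidLanz/Taiwan-tinyllama-v1.0-chat",
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Recent transformers releases accept a list of role/content dicts and apply
# the model's chat template before generating.
out = pipe(messages, max_new_tokens=64, do_sample=True, temperature=0.7)

# The reply is the last message of the returned conversation.
print(out[0]["generated_text"][-1]["content"])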
app.py CHANGED
@@ -6,12 +6,11 @@ from itertools import islice
 from datetime import datetime
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from duckduckgo_search import DDGS
 import spaces  # Import spaces early to enable ZeroGPU support
 
-# Disable GPU visibility if you wish to force CPU usage outside of GPU functions
-# (Not strictly needed for ZeroGPU as the decorator handles allocation)
+# Optional: Disable GPU visibility if you wish to force CPU usage
 # os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 # ------------------------------
@@ -22,9 +21,6 @@ cancel_event = threading.Event()
 # ------------------------------
 # Torch-Compatible Model Definitions with Adjusted Descriptions
 # ------------------------------
-# ------------------------------
-# Torch-Compatible Model Definitions (Cleaned)
-# ------------------------------
 MODELS = {
     "Taiwan-tinyllama-v1.0-chat": {
         "repo_id": "DavidLanz/Taiwan-tinyllama-v1.0-chat",
@@ -72,34 +68,35 @@ MODELS = {
     },
 }
 
-LOADED_MODELS = {}
-CURRENT_MODEL_NAME = None
+# Global cache for pipelines to avoid re-loading.
+PIPELINES = {}
 
-# ------------------------------
-# Model Loading Helper Function (PyTorch/Transformers)
-# ------------------------------
-def load_model(model_name):
-    global LOADED_MODELS, CURRENT_MODEL_NAME
-    if model_name in LOADED_MODELS:
-        return LOADED_MODELS[model_name]
+def load_pipeline(model_name):
+    """
+    Load and cache a transformers pipeline for chat/text-generation.
+    Uses the model's repo_id from MODELS and caches the pipeline for future use.
+    """
+    global PIPELINES
+    if model_name in PIPELINES:
+        return PIPELINES[model_name]
     selected_model = MODELS[model_name]
-    # Load the model and tokenizer using Transformers.
-    model = AutoModelForCausalLM.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
-    tokenizer = AutoTokenizer.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
-
-    # If the pad token is missing or the same as the eos token, add a new pad token.
-    if tokenizer.pad_token is None or tokenizer.pad_token == tokenizer.eos_token:
-        tokenizer.add_special_tokens({'pad_token': '<pad>'})
-        model.resize_token_embeddings(len(tokenizer))
-
-    LOADED_MODELS[model_name] = (model, tokenizer)
-    CURRENT_MODEL_NAME = model_name
-    return model, tokenizer
+    # Create a chat-style text-generation pipeline.
+    pipe = pipeline(
+        task="text-generation",
+        model=selected_model["repo_id"],
+        tokenizer=selected_model["repo_id"],
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+        device_map="auto"
+    )
+    PIPELINES[model_name] = pipe
+    return pipe
 
-# ------------------------------
-# Web Search Context Retrieval Function
-# ------------------------------
 def retrieve_context(query, max_results=6, max_chars_per_result=600):
+    """
+    Retrieve recent web search context for the given query using DuckDuckGo.
+    Returns a formatted string with search results.
+    """
     try:
         with DDGS() as ddgs:
             results = list(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))
@@ -113,23 +110,31 @@ def retrieve_context(query, max_results=6, max_chars_per_result=600):
         return ""
 
 # ------------------------------
-# Chat Response Generation with ZeroGPU
+# Chat Response Generation with ZeroGPU using Pipeline
 # ------------------------------
-@spaces.GPU(duration=60)  # This decorator triggers GPU allocation for up to 60 seconds.
+@spaces.GPU(duration=60)
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
-    # Reset the cancellation event.
+    """
+    Generate a chat response by utilizing a transformers pipeline.
+
+    - Appends the user's message to the conversation history.
+    - Optionally retrieves web search context and inserts it as an additional system message.
+    - Uses a cached pipeline (loaded via load_pipeline) to generate a response.
+    - Returns the updated conversation history and a debug message.
+    """
    cancel_event.clear()
 
-    # Prepare internal chat history.
-    internal_history = list(chat_history) if chat_history else []
-    internal_history.append({"role": "user", "content": user_message})
+    # Build conversation list from chat history.
+    conversation = list(chat_history) if chat_history else []
+    conversation.append({"role": "user", "content": user_message})
 
-    # Retrieve web search context (with debug feedback).
+    # Retrieve web search context if enabled.
     debug_message = ""
+    retrieved_context = ""
     if enable_search:
         debug_message = "Initiating web search..."
-        yield internal_history, debug_message
+        yield conversation, debug_message
         search_result = [""]
         def do_search():
             search_result[0] = retrieve_context(user_message, max_results, max_chars)
@@ -139,71 +144,46 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
         retrieved_context = search_result[0]
         if retrieved_context:
             debug_message = f"Web search results:\n\n{retrieved_context}"
+            # Insert the search context as a system-level message immediately after the original system prompt.
+            conversation.insert(1, {"role": "system", "content": f"Web search context:\n{retrieved_context}"})
         else:
             debug_message = "Web search returned no results or timed out."
     else:
-        retrieved_context = ""
         debug_message = "Web search disabled."
 
-    # Augment the prompt with search context if available.
-    if enable_search and retrieved_context:
-        augmented_user_input = (
-            f"{system_prompt.strip()}\n\n"
-            "Use the following recent web search context to help answer the query:\n\n"
-            f"{retrieved_context}\n\n"
-            f"User Query: {user_message}"
-        )
-    else:
-        augmented_user_input = f"{system_prompt.strip()}\n\nUser Query: {user_message}"
-
     # Append a placeholder for the assistant's response.
-    internal_history.append({"role": "assistant", "content": ""})
+    conversation.append({"role": "assistant", "content": ""})
 
     try:
-        # Load the model and tokenizer.
-        model, tokenizer = load_model(model_name)
-        # Move the model to GPU (using .to('cuda')) inside the GPU-decorated function.
-        model = model.to('cuda')
+        # Load the pipeline (cached) for the selected model.
+        pipe = load_pipeline(model_name)
 
-        # Tokenize the augmented prompt with padding and retrieve the attention mask.
-        encoding = tokenizer(augmented_user_input, return_tensors="pt", padding=True)
-        input_ids = encoding["input_ids"].to('cuda')
-        attention_mask = encoding["attention_mask"].to('cuda')
+        # Use the pipeline directly with conversation history.
+        # Note: Many chat pipelines use internal chat templating to properly format the conversation.
+        response = pipe(
+            conversation,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repeat_penalty,
+        )
+        # Extract the assistant's reply.
+        try:
+            assistant_text = response[0]["generated_text"][-1]["content"]
+        except (KeyError, IndexError, TypeError):
+            assistant_text = response[0]["generated_text"]
 
-        with torch.no_grad():
-            output_ids = model.generate(
-                input_ids,
-                attention_mask=attention_mask,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                repetition_penalty=repeat_penalty,
-                do_sample=True,
-                pad_token_id=tokenizer.pad_token_id,
-            )
-        # Decode the generated tokens.
-        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        # Remove the original prompt to isolate the assistant's reply.
-        assistant_text = generated_text[len(augmented_user_input):].strip()
+        # Update the conversation history.
+        conversation[-1]["content"] = assistant_text
 
-        # Simulate streaming output by yielding word-by-word.
-        words = assistant_text.split()
-        assistant_message = ""
-        for word in words:
-            if cancel_event.is_set():
-                assistant_message += "\n\n[Response generation cancelled by user]"
-                internal_history[-1]["content"] = assistant_message
-                yield internal_history, debug_message
-                return
-            assistant_message += word + " "
-            internal_history[-1]["content"] = assistant_message
-            yield internal_history, debug_message
-            time.sleep(0.05)  # Short delay to simulate streaming
+        # Yield the complete conversation history and the debug message.
+        yield conversation, debug_message
     except Exception as e:
-        internal_history[-1]["content"] = f"Error: {e}"
-        yield internal_history, debug_message
-        gc.collect()
+        conversation[-1]["content"] = f"Error: {e}"
+        yield conversation, debug_message
+    finally:
+        gc.collect()
 
 # ------------------------------
 # Cancel Function
@@ -265,7 +245,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
     cancel_button.click(fn=cancel_generation, outputs=search_debug)
 
-    # Submission: the chat_response function is now decorated with @spaces.GPU.
+    # Submission: the chat_response function is now used with the Transformers pipeline.
    msg_input.submit(
        fn=chat_response,
        inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
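
Note on the new reply extraction: the nested try/except around response[0]["generated_text"] covers the two shapes a text-generation pipeline can return. For chat-style input (a list of role/content dicts), recent transformers versions return the whole conversation, including the new assistant message, under generated_text; for a plain string prompt, generated_text is itself a string, so the fallback branch applies. A rough illustration with made-up contents:

# Chat-style input: generated_text is the conversation plus the new reply.
chat_out = [{"generated_text": [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]}]
assert chat_out[0]["generated_text"][-1]["content"] == "Hello! How can I help?"

# Plain-string prompt: generated_text is just a string, which is what the
# except (KeyError, IndexError, TypeError) fallback handles.
text_out = [{"generated_text": "Hi Hello! How can I help?"}]
assert isinstance(text_out[0]["generated_text"], str)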