Luigi committed
Commit 4c6b4c5 · 1 Parent(s): 4afc958

pin torch to 2.4.0

Files changed (2):
  1. app.py (+19 -19)
  2. requirements.txt (+2 -2)
app.py CHANGED
@@ -7,19 +7,18 @@ from datetime import datetime
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import hf_hub_download
 from duckduckgo_search import DDGS
-import spaces
+import spaces  # Import spaces early to enable ZeroGPU support
+
+# Disable GPU visibility if you wish to force CPU usage outside of GPU functions
+# (Not strictly needed for ZeroGPU as the decorator handles allocation)
+# os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 # ------------------------------
 # Global Cancellation Event
 # ------------------------------
 cancel_event = threading.Event()
 
-# ------------------------------
-# Model Definitions and Global Variables (PyTorch/Transformers)
-# ------------------------------
-# Here, the repo_id should point to a model checkpoint that is compatible with Hugging Face Transformers.
 # ------------------------------
 # Torch-Compatible Model Definitions with Adjusted Descriptions
 # ------------------------------
@@ -70,7 +69,6 @@ MODELS = {
     },
 }
 
-
 LOADED_MODELS = {}
 CURRENT_MODEL_NAME = None
 
@@ -82,7 +80,7 @@ def load_model(model_name):
     if model_name in LOADED_MODELS:
         return LOADED_MODELS[model_name]
     selected_model = MODELS[model_name]
-    # Load both the model and tokenizer using the Transformers library.
+    # Load the model and tokenizer using Transformers.
     model = AutoModelForCausalLM.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
     LOADED_MODELS[model_name] = (model, tokenizer)
@@ -106,15 +104,15 @@ def retrieve_context(query, max_results=6, max_chars_per_result=600):
         return ""
 
 # ------------------------------
-# Chat Response Generation (Simulated Streaming) with Cancellation
+# Chat Response Generation with ZeroGPU
 # ------------------------------
-@spaces.GPU
+@spaces.GPU(duration=60)  # This decorator triggers GPU allocation for up to 60 seconds.
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
     # Reset the cancellation event.
     cancel_event.clear()
 
-    # Prepare internal history.
+    # Prepare internal chat history.
     internal_history = list(chat_history) if chat_history else []
     internal_history.append({"role": "user", "content": user_message})
 
@@ -138,7 +136,7 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
         retrieved_context = ""
         debug_message = "Web search disabled."
 
-    # Augment prompt with search context if available.
+    # Augment the prompt with search context if available.
     if enable_search and retrieved_context:
         augmented_user_input = (
             f"{system_prompt.strip()}\n\n"
@@ -153,11 +151,13 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
     internal_history.append({"role": "assistant", "content": ""})
 
     try:
-        # Load the PyTorch model and tokenizer.
+        # Load the model and tokenizer.
         model, tokenizer = load_model(model_name)
+        # Move the model to GPU (using .to('cuda')) inside the GPU-decorated function.
+        model = model.to('cuda')
+        # Tokenize the augmented prompt and move input tensors to GPU.
+        input_ids = tokenizer(augmented_user_input, return_tensors="pt").input_ids.to('cuda')
 
-        # Tokenize the input prompt.
-        input_ids = tokenizer(augmented_user_input, return_tensors="pt").input_ids
         with torch.no_grad():
             output_ids = model.generate(
                 input_ids,
@@ -168,13 +168,12 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
                 repetition_penalty=repeat_penalty,
                 do_sample=True
             )
-
         # Decode the generated tokens.
         generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        # Strip the original prompt to isolate the assistants reply.
+        # Remove the original prompt to isolate the assistant's reply.
         assistant_text = generated_text[len(augmented_user_input):].strip()
 
-        # Simulate streaming by yielding the output word by word.
+        # Simulate streaming output by yielding word-by-word.
         words = assistant_text.split()
         assistant_message = ""
         for word in words:
@@ -205,7 +204,7 @@ def cancel_generation():
 with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
     gr.Markdown("Interact with the model. Select your model, set your system prompt, and adjust parameters on the left.")
-
+
     with gr.Row():
         with gr.Column(scale=3):
             default_model = list(MODELS.keys())[0] if MODELS else "No models available"
@@ -252,6 +251,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
     cancel_button.click(fn=cancel_generation, outputs=search_debug)
 
+    # Submission: the chat_response function is now decorated with @spaces.GPU.
    msg_input.submit(
         fn=chat_response,
         inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
 
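For reference, the core ZeroGPU pattern this commit adopts looks like the sketch below: load the model lazily on CPU, then do the .to('cuda') transfer and the generation inside the @spaces.GPU-decorated function, because ZeroGPU only attaches a GPU for the duration of that call. This is a minimal sketch under those assumptions, not the app itself; run_generation and the bare repo_id parameter are illustrative names.

    import spaces
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    _CACHE = {}

    def load_model(repo_id):
        # Lazy-load on CPU and cache; no GPU exists outside decorated calls.
        if repo_id not in _CACHE:
            model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
            tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
            _CACHE[repo_id] = (model, tokenizer)
        return _CACHE[repo_id]

    @spaces.GPU(duration=60)  # GPU is allocated only while this function runs.
    def run_generation(repo_id, prompt, max_tokens=256):
        model, tokenizer = load_model(repo_id)
        model = model.to('cuda')  # move weights onto the freshly attached GPU
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to('cuda')
        with torch.no_grad():
            output_ids = model.generate(input_ids, max_new_tokens=max_tokens, do_sample=True)
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return text[len(prompt):].strip()  # drop the echoed prompt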
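The "simulated streaming" the diff keeps is just a generator that re-emits the finished reply word by word, checking the shared cancellation event between words. A standalone sketch of that loop (the delay value is chosen arbitrarily here):

    import threading
    import time

    cancel_event = threading.Event()

    def stream_words(text, delay=0.02):
        # Yield a growing prefix of the reply until cancelled.
        assembled = ""
        for word in text.split():
            if cancel_event.is_set():
                break
            assembled += word + " "
            time.sleep(delay)
            yield assembled.strip()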
requirements.txt CHANGED
@@ -1,8 +1,8 @@
 wheel
 streamlit
 duckduckgo_search
-gradio
-torch
+gradio>=4.0.0
+torch==2.4.0
 transformers
 spaces
 sentencepiece
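Pinning torch to an exact version keeps the Space's build reproducible; the commit message gives no further rationale. A quick post-install sanity check (illustrative, not part of the commit) that the pin took effect:

    import torch

    # Local builds report e.g. "2.4.0+cu121"; the version prefix is what matters.
    assert torch.__version__.startswith("2.4.0"), f"unexpected torch {torch.__version__}"
    print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())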