Spaces: Running on Zero
pin torch to 2.4.0
- app.py +19 -19
- requirements.txt +2 -2
app.py
CHANGED
@@ -7,19 +7,18 @@ from datetime import datetime
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import hf_hub_download
 from duckduckgo_search import DDGS
-import spaces
+import spaces  # Import spaces early to enable ZeroGPU support
+
+# Disable GPU visibility if you wish to force CPU usage outside of GPU functions
+# (Not strictly needed for ZeroGPU as the decorator handles allocation)
+# os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 # ------------------------------
 # Global Cancellation Event
 # ------------------------------
 cancel_event = threading.Event()
 
-# ------------------------------
-# Model Definitions and Global Variables (PyTorch/Transformers)
-# ------------------------------
-# Here, the repo_id should point to a model checkpoint that is compatible with Hugging Face Transformers.
 # ------------------------------
 # Torch-Compatible Model Definitions with Adjusted Descriptions
 # ------------------------------
@@ -70,7 +69,6 @@ MODELS = {
     },
 }
 
-
 LOADED_MODELS = {}
 CURRENT_MODEL_NAME = None
 
@@ -82,7 +80,7 @@ def load_model(model_name):
     if model_name in LOADED_MODELS:
         return LOADED_MODELS[model_name]
     selected_model = MODELS[model_name]
-    # Load
+    # Load the model and tokenizer using Transformers.
     model = AutoModelForCausalLM.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
     LOADED_MODELS[model_name] = (model, tokenizer)
@@ -106,15 +104,15 @@ def retrieve_context(query, max_results=6, max_chars_per_result=600):
         return ""
 
 # ------------------------------
-# Chat Response Generation
+# Chat Response Generation with ZeroGPU
 # ------------------------------
-@spaces.GPU
+@spaces.GPU(duration=60)  # This decorator triggers GPU allocation for up to 60 seconds.
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
     # Reset the cancellation event.
     cancel_event.clear()
 
-    # Prepare internal history.
+    # Prepare internal chat history.
     internal_history = list(chat_history) if chat_history else []
     internal_history.append({"role": "user", "content": user_message})
 
@@ -138,7 +136,7 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
         retrieved_context = ""
         debug_message = "Web search disabled."
 
-    # Augment prompt with search context if available.
+    # Augment the prompt with search context if available.
     if enable_search and retrieved_context:
         augmented_user_input = (
             f"{system_prompt.strip()}\n\n"
@@ -153,11 +151,13 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
     internal_history.append({"role": "assistant", "content": ""})
 
     try:
-        # Load the
+        # Load the model and tokenizer.
        	model, tokenizer = load_model(model_name)
+        # Move the model to GPU (using .to('cuda')) inside the GPU-decorated function.
+        model = model.to('cuda')
+        # Tokenize the augmented prompt and move input tensors to GPU.
+        input_ids = tokenizer(augmented_user_input, return_tensors="pt").input_ids.to('cuda')
 
-        # Tokenize the input prompt.
-        input_ids = tokenizer(augmented_user_input, return_tensors="pt").input_ids
         with torch.no_grad():
             output_ids = model.generate(
                 input_ids,
@@ -168,13 +168,12 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
                 repetition_penalty=repeat_penalty,
                 do_sample=True
             )
-
         # Decode the generated tokens.
         generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        #
+        # Remove the original prompt to isolate the assistant's reply.
         assistant_text = generated_text[len(augmented_user_input):].strip()
 
-        # Simulate streaming by yielding
+        # Simulate streaming output by yielding word-by-word.
         words = assistant_text.split()
        	assistant_message = ""
        	for word in words:
@@ -205,7 +204,7 @@ def cancel_generation():
 with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
     gr.Markdown("Interact with the model. Select your model, set your system prompt, and adjust parameters on the left.")
-
+
     with gr.Row():
         with gr.Column(scale=3):
            	default_model = list(MODELS.keys())[0] if MODELS else "No models available"
@@ -252,6 +251,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
     cancel_button.click(fn=cancel_generation, outputs=search_debug)
 
+    # Submission: the chat_response function is now decorated with @spaces.GPU.
     msg_input.submit(
         fn=chat_response,
         inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
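For reference, the app.py changes above follow the usual ZeroGPU pattern: keep the model on CPU at import time and do all CUDA work inside a `@spaces.GPU`-decorated function, streaming the reply back by yielding partial text. The sketch below restates that pattern in isolation; the checkpoint id, function name, and generation arguments are illustrative assumptions, not the Space's actual values.

```python
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

REPO_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # hypothetical example checkpoint

# Loaded on CPU at import time; no GPU is attached yet on ZeroGPU.
model = AutoModelForCausalLM.from_pretrained(REPO_ID, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)

@spaces.GPU(duration=60)  # GPU is allocated only while this function runs
def generate(prompt: str, max_tokens: int = 128):
    # Move the model and input tensors to CUDA inside the decorated function.
    m = model.to("cuda")
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    with torch.no_grad():
        output_ids = m.generate(input_ids, max_new_tokens=max_tokens, do_sample=True)
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Strip the prompt to isolate the reply, then simulate streaming word by word.
    reply = text[len(prompt):].strip()
    partial = ""
    for word in reply.split():
        partial += word + " "
        yield partial.strip()
```

Because the function is a generator, Gradio can render each yielded string as a progressively growing chat message, which is what the "simulate streaming" loop in chat_response does.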
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
 wheel
 streamlit
 duckduckgo_search
-gradio
-torch
+gradio>=4.0.0
+torch==2.4.0
 transformers
 spaces
 sentencepiece
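Pinning `torch==2.4.0` (and requiring `gradio>=4.0.0`) keeps the build reproducible on ZeroGPU hardware. If a runtime sanity check is wanted, a hypothetical helper like the one below can confirm the installed torch build and CUDA visibility from inside a GPU-decorated function; this is a sketch, not part of the Space.

```python
import spaces
import torch

@spaces.GPU
def check_environment() -> str:
    # Report the installed torch build and whether CUDA is visible
    # while the ZeroGPU allocation is active. With this requirements.txt,
    # the version string should start with "2.4.0".
    return f"torch {torch.__version__}, cuda available: {torch.cuda.is_available()}"
```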