Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,91 +1,137 @@
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+from datasets import load_dataset # To get a speaker embedding for TTS
 import os
-import spaces
+import spaces # Import the spaces library for GPU decorator
+import tempfile # For creating temporary audio files
+import soundfile as sf # To save audio files

-# --- Configuration ---
-# IMPORTANT:
-#
-# LOCAL_MODEL_PATH = "/path/to/your/downloaded/qwen-1.5b-instruct"
-# HUGGINGFACE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat" # For a smaller Qwen model for local testing
+# --- Configuration for Language Model (LLM) ---
+# IMPORTANT: When deploying to Hugging Face Spaces, it's best to use the Hugging Face model ID
+# rather than a local path ('.'), as the Space will fetch it from the Hub.
 HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"

 # You might need to adjust TORCH_DTYPE based on your GPU and model support
 # torch.float16 (FP16) is common for inference, torch.bfloat16 for newer GPUs
-
+# For ZeroGPU/H200, bfloat16 is often preferred if the model supports it and GPU allows.
+TORCH_DTYPE = torch.bfloat16 # Use bfloat16 for optimal H200 performance

-# Generation parameters (can be adjusted for different response styles)
+# Generation parameters for the LLM (can be adjusted for different response styles)
 MAX_NEW_TOKENS = 512
 DO_SAMPLE = True
 TEMPERATURE = 0.7
 TOP_K = 50
 TOP_P = 0.95

+# --- Configuration for Text-to-Speech (TTS) ---
+TTS_MODEL_ID = "microsoft/speecht5_tts"
+TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
+
 # --- Global variables for models and tokenizers ---
 tokenizer = None
-
+llm_model = None # Renamed to avoid conflict with tts_model
+tts_processor = None
+tts_model = None
+tts_vocoder = None
+speaker_embeddings = None # Global for TTS speaker embedding

 # --- Load Models and Tokenizers Function ---
-@spaces.GPU
-def
+@spaces.GPU # Decorate with @spaces.GPU to signal this function needs GPU access
+def load_models():
     """
-    Loads the language model
-    This function will be called once when the Gradio app starts up.
+    Loads the language model, tokenizer, TTS models, and speaker embeddings
+    from Hugging Face Hub. This function will be called once when the Gradio app starts up.
     """
-    global tokenizer,
+    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

-    if tokenizer is not None and
-        print("
+    if tokenizer is not None and llm_model is not None and tts_model is not None:
+        print("All models and tokenizers already loaded.")
         return

-
+    # When deploying to HF Spaces, you generally don't need an explicit HF_TOKEN
+    # for public models, but it's good practice for private models or if
+    # rate limits are hit.
+    hf_token = os.environ.get("HF_TOKEN") # Access HF_TOKEN from Space secrets if set
+
+    # Load Language Model (LLM)
+    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
     try:
-        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
+        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
             print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

-        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
-
+        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
+        llm_model = AutoModelForCausalLM.from_pretrained(
             HUGGINGFACE_MODEL_ID,
             torch_dtype=TORCH_DTYPE,
-            device_map="auto" # Automatically maps model to GPU if available, else CPU
+            device_map="auto", # Automatically maps model to GPU if available, else CPU
+            token=hf_token # Pass token if loading private model
         )
-
-        print("
+        llm_model.eval() # Set model to evaluation mode
+        print("LLM model loaded successfully.")
     except Exception as e:
-        print(f"Error loading model or tokenizer: {e}")
-        print("Please ensure the model ID is correct and you have an internet connection for initial download, or the local path is valid.")
+        print(f"Error loading LLM model or tokenizer: {e}")
+        print("Please ensure the LLM model ID is correct and you have an internet connection for initial download, or the local path is valid.")
         tokenizer = None
-
-        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.")
+        llm_model = None
+        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")
+
+    # Load TTS models
+    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
+    try:
+        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
+        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
+        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
+
+        # Load a speaker embedding (essential for SpeechT5 TTS)
+        # Using a sample from a public dataset for demonstration
+        print("Loading speaker embeddings for TTS...")
+        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
+        # Using a specific speaker embedding (you can experiment with different indices)
+        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

+        # Move TTS components to the same device as the LLM model
+        device = llm_model.device if llm_model else 'cpu'
+        tts_model.to(device)
+        tts_vocoder.to(device)
+        speaker_embeddings = speaker_embeddings.to(device)
+        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")

+    except Exception as e:
+        print(f"Error loading TTS models or speaker embeddings: {e}")
+        print("Please ensure TTS model IDs are correct and you have an internet connection.")
+        tts_processor = None
+        tts_model = None
+        tts_vocoder = None
+        speaker_embeddings = None
+        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")
+
+
-
-
-
+# --- Generate Response and Audio Function ---
+@spaces.GPU # Decorate with @spaces.GPU as this function performs GPU-intensive inference
+def generate_response_and_audio(
     message: str, # Current user message
     history: list # Gradio Chatbot history format (list of dictionaries with 'role' and 'content')
-) ->
+) -> tuple: # Returns (updated_history, audio_file_path)
     """
-    Generates a text response from the loaded
+    Generates a text response from the loaded LLM and then converts it to audio
+    using the loaded TTS model.
     """
-    global tokenizer,
+    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

-    # Initialize models if not already loaded
-    if tokenizer is None or
-
+    # Initialize all models if not already loaded
+    if tokenizer is None or llm_model is None or tts_model is None:
+        load_models()

-    if tokenizer is None or
-        # history.append([message, "Error: Chatbot model not loaded. Please check logs."])
-        # For 'messages' type history, append a dictionary
+    if tokenizer is None or llm_model is None: # Check LLM loading status
         history.append({"role": "user", "content": message})
-        history.append({"role": "assistant", "content": "Error: Chatbot
-        return history
+        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
+        return history, None

-    #
-    #
+    # --- 1. Generate Text Response (LLM) ---
+    # Format messages for the model's chat template
     messages = history # Use history directly as it's already in the correct format
     messages.append({"role": "user", "content": message}) # Add current user message

@@ -94,8 +140,7 @@ def generate_response(
         input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     except Exception as e:
         print(f"Error applying chat template: {e}")
-        # Fallback
-        # Reconstruct input_text for models without explicit chat templates
+        # Fallback for models without explicit chat templates
         input_text = ""
         for item in history:
             if item["role"] == "user":
@@ -104,11 +149,11 @@ def generate_response(
                 input_text += f"Assistant: {item['content']}\n"
         input_text += f"User: {message}\nAssistant:"

-    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)

     # Generate response
     with torch.no_grad(): # Disable gradient calculations for inference
-        output_ids =
+        output_ids = llm_model.generate(
             input_ids,
             max_new_tokens=MAX_NEW_TOKENS,
             do_sample=DO_SAMPLE,
@@ -122,11 +167,45 @@ def generate_response(
     generated_token_ids = output_ids[0][input_ids.shape[-1]:]
     generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

-    # ---
+    # --- 2. Generate Audio from Response (TTS) ---
+    audio_path = None
+    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
+        try:
+            # Ensure TTS components are on the correct device
+            device = llm_model.device if llm_model else 'cpu'
+            tts_model.to(device)
+            tts_vocoder.to(device)
+            speaker_embeddings = speaker_embeddings.to(device)
+
+            tts_inputs = tts_processor(
+                text=generated_text,
+                return_tensors="pt",
+                max_length=550, # Set a max length to prevent excessively long audio
+                truncation=True # Enable truncation if text exceeds max_length
+            ).to(device)
+
+            with torch.no_grad():
+                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
+
+            # Create a temporary file to save the audio
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                audio_path = tmp_file.name
+                # Ensure audio data is on CPU before saving with soundfile
+                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
+            print(f"Audio saved to: {audio_path}")
+
+        except Exception as e:
+            print(f"Error generating audio: {e}")
+            audio_path = None # Return None if audio generation fails
+    else:
+        print("TTS components not loaded. Skipping audio generation.")
+
+
+    # --- 3. Update Chat History ---
     # Append the latest generated response to the history with its role
     history.append({"role": "assistant", "content": generated_text})

-    return history
+    return history, audio_path

 # --- Gradio Interface ---
 with gr.Blocks() as demo:
@@ -147,34 +226,37 @@ with gr.Blocks() as demo:
         )
         submit_button = gr.Button("Send", scale=1)

+    audio_output = gr.Audio(
+        label="Listen to Response",
+        autoplay=True, # Automatically play audio
+        interactive=False # Don't allow user to interact with this audio component
+    )
+
     # Link the text input and button to the generation function
-    #
-    # 'outputs' will be the updated full history
+    # Outputs now include both the chatbot history and the audio file path
     submit_button.click(
-        fn=
-        inputs=[text_input, chatbot],
-        outputs=[chatbot],
+        fn=generate_response_and_audio,
+        inputs=[text_input, chatbot],
+        outputs=[chatbot, audio_output],
         queue=True # Queue requests for better concurrency
     )
     text_input.submit( # Also trigger on Enter key
-        fn=
+        fn=generate_response_and_audio,
         inputs=[text_input, chatbot],
-        outputs=[chatbot],
+        outputs=[chatbot, audio_output],
         queue=True
     )

     # Clear button
     def clear_chat():
-        #
-
-        return [], ""
+        # Clear history, text input, and audio output
+        return [], "", None
     clear_button = gr.Button("Clear Chat")
-    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])
+    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input, audio_output])


-# Load
-
+# Load all models when the app starts up
+load_models()

 # Launch the Gradio app
-
-demo.queue().launch(server_name="0.0.0.0")
+demo.queue().launch()
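
The new imports mean the Space now needs more than the original gradio/torch/transformers stack. A minimal requirements.txt sketch, with package names inferred from the imports above; pins are omitted, and the extra entries (accelerate for device_map="auto", sentencepiece for the SpeechT5 tokenizer) are assumptions rather than something this commit specifies:

gradio
torch
transformers
accelerate      # assumed: needed for device_map="auto"
datasets        # speaker-embedding dataset
soundfile       # writes the .wav passed to gr.Audio
sentencepiece   # assumed: SpeechT5 tokenizer dependency
spaces          # ZeroGPU decorator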
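
To sanity-check the TTS path added here without spinning up the full Gradio app, the same SpeechT5 calls can be run standalone. This is a sketch, not part of app.py; it reuses the model IDs and the xvector index from the code above and writes a local smoke_test.wav:

# Standalone check of the TTS pipeline used in this commit (sketch only).
import torch
import soundfile as sf
from datasets import load_dataset
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Same speaker-embedding source and index as app.py
embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)

inputs = processor(text="Hello from the updated Space!", return_tensors="pt")
with torch.no_grad():
    speech = tts_model.generate_speech(inputs["input_ids"], speaker, vocoder=vocoder)

# SpeechT5 produces 16 kHz mono audio
sf.write("smoke_test.wav", speech.cpu().numpy(), samplerate=16000)
print("Wrote smoke_test.wav with", speech.shape[0], "samples")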