ajsbsd committed on
Commit b1f41e6 · verified · 1 parent: b96e3f8

Update app.py

Files changed (1): app.py (+154, −86)
app.py CHANGED
@@ -1,23 +1,24 @@
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    SpeechT5Processor,
+    SpeechT5ForTextToSpeech,
+    SpeechT5HifiGan,
+    WhisperProcessor, # New: For Speech-to-Text
+    WhisperForConditionalGeneration # New: For Speech-to-Text
+)
 from datasets import load_dataset # To get a speaker embedding for TTS
 import os
 import spaces # Import the spaces library for GPU decorator
 import tempfile # For creating temporary audio files
 import soundfile as sf # To save audio files
+import librosa # New: For loading audio files for transcription
 
 # --- Configuration for Language Model (LLM) ---
-# IMPORTANT: When deploying to Hugging Face Spaces, it's best to use the Hugging Face model ID
-# rather than a local path ('.'), as the Space will fetch it from the Hub.
 HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
-
-# You might need to adjust TORCH_DTYPE based on your GPU and model support
-# torch.float16 (FP16) is common for inference, torch.bfloat16 for newer GPUs
-# For ZeroGPU/H200, bfloat16 is often preferred if the model supports it and GPU allows.
-TORCH_DTYPE = torch.bfloat16 # Use bfloat16 for optimal H200 performance
-
-# Generation parameters for the LLM (can be adjusted for different response styles)
+TORCH_DTYPE = torch.bfloat16
 MAX_NEW_TOKENS = 512
 DO_SAMPLE = True
 TEMPERATURE = 0.7
@@ -28,31 +29,36 @@ TOP_P = 0.95
 TTS_MODEL_ID = "microsoft/speecht5_tts"
 TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
 
-# --- Global variables for models and tokenizers ---
+# --- Configuration for Speech-to-Text (STT) ---
+STT_MODEL_ID = "openai/whisper-tiny" # Using a smaller Whisper model for faster inference
+
+# --- Global variables for models and tokenizers/processors ---
 tokenizer = None
-llm_model = None # Renamed to avoid conflict with tts_model
+llm_model = None
 tts_processor = None
 tts_model = None
 tts_vocoder = None
-speaker_embeddings = None # Global for TTS speaker embedding
+speaker_embeddings = None
+whisper_processor = None # New: Global for Whisper processor
+whisper_model = None # New: Global for Whisper model
 
-# --- Load Models and Tokenizers Function ---
+# --- Load All Models Function ---
 @spaces.GPU # Decorate with @spaces.GPU to signal this function needs GPU access
 def load_models():
     """
-    Loads the language model, tokenizer, TTS models, and speaker embeddings
-    from Hugging Face Hub. This function will be called once when the Gradio app starts up.
+    Loads the language model, tokenizer, TTS models, speaker embeddings,
+    and STT (Whisper) models from Hugging Face Hub.
+    This function will be called once when the Gradio app starts up.
     """
     global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
+    global whisper_processor, whisper_model
 
-    if tokenizer is not None and llm_model is not None and tts_model is not None:
-        print("All models and tokenizers already loaded.")
+    if (tokenizer is not None and llm_model is not None and tts_model is not None and
+            whisper_processor is not None and whisper_model is not None):
+        print("All models and tokenizers/processors already loaded.")
         return
 
-    # When deploying to HF Spaces, you generally don't need an explicit HF_TOKEN
-    # for public models, but it's good practice for private models or if
-    # rate limits are hit.
-    hf_token = os.environ.get("HF_TOKEN") # Access HF_TOKEN from Space secrets if set
+    hf_token = os.environ.get("HF_TOKEN")
 
     # Load Language Model (LLM)
     print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
@@ -66,16 +72,13 @@ def load_models():
         llm_model = AutoModelForCausalLM.from_pretrained(
             HUGGINGFACE_MODEL_ID,
             torch_dtype=TORCH_DTYPE,
-            device_map="auto", # Automatically maps model to GPU if available, else CPU
-            token=hf_token # Pass token if loading private model
+            device_map="auto",
+            token=hf_token
         )
-        llm_model.eval() # Set model to evaluation mode
+        llm_model.eval()
         print("LLM model loaded successfully.")
     except Exception as e:
         print(f"Error loading LLM model or tokenizer: {e}")
-        print("Please ensure the LLM model ID is correct and you have an internet connection for initial download, or the local path is valid.")
-        tokenizer = None
-        llm_model = None
         raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")
 
     # Load TTS models
@@ -85,14 +88,10 @@ def load_models():
         tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
         tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
 
-        # Load a speaker embedding (essential for SpeechT5 TTS)
-        # Using a sample from a public dataset for demonstration
         print("Loading speaker embeddings for TTS...")
         embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
-        # Using a specific speaker embedding (you can experiment with different indices)
         speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
 
-        # Move TTS components to the same device as the LLM model
         device = llm_model.device if llm_model else 'cpu'
         tts_model.to(device)
         tts_vocoder.to(device)
@@ -101,13 +100,27 @@ def load_models():
 
     except Exception as e:
         print(f"Error loading TTS models or speaker embeddings: {e}")
-        print("Please ensure TTS model IDs are correct and you have an internet connection.")
         tts_processor = None
         tts_model = None
         tts_vocoder = None
         speaker_embeddings = None
         raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")
 
+    # Load STT (Whisper) model
+    print(f"Loading STT (Whisper) processor and model from: {STT_MODEL_ID}")
+    try:
+        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
+        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)
+
+        device = llm_model.device if llm_model else 'cpu' # Use the same device as LLM
+        whisper_model.to(device)
+        print(f"STT (Whisper) model loaded successfully to device: {device}.")
+    except Exception as e:
+        print(f"Error loading STT (Whisper) model or processor: {e}")
+        whisper_processor = None
+        whisper_model = None
+        raise RuntimeError("Failed to load STT (Whisper) components. Check model ID and internet connection.")
+
 
 # --- Generate Response and Audio Function ---
 @spaces.GPU # Decorate with @spaces.GPU as this function performs GPU-intensive inference
@@ -131,16 +144,13 @@ def generate_response_and_audio(
         return history, None
 
     # --- 1. Generate Text Response (LLM) ---
-    # Format messages for the model's chat template
-    messages = history # Use history directly as it's already in the correct format
-    messages.append({"role": "user", "content": message}) # Add current user message
+    messages = history
+    messages.append({"role": "user", "content": message})
 
-    # Apply the chat template and tokenize
    try:
         input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     except Exception as e:
         print(f"Error applying chat template: {e}")
-        # Fallback for models without explicit chat templates
         input_text = ""
         for item in history:
             if item["role"] == "user":
@@ -151,8 +161,7 @@ def generate_response_and_audio(
 
     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)
 
-    # Generate response
-    with torch.no_grad(): # Disable gradient calculations for inference
+    with torch.no_grad():
        output_ids = llm_model.generate(
             input_ids,
             max_new_tokens=MAX_NEW_TOKENS,
@@ -160,10 +169,9 @@
             temperature=TEMPERATURE,
             top_k=TOP_K,
             top_p=TOP_P,
-            pad_token_id=tokenizer.eos_token_id # Important for generation to stop cleanly
+            pad_token_id=tokenizer.eos_token_id
         )
 
-    # Decode the generated text, excluding the input prompt part
     generated_token_ids = output_ids[0][input_ids.shape[-1]:]
     generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
 
@@ -171,7 +179,6 @@
     audio_path = None
     if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
         try:
-            # Ensure TTS components are on the correct device
             device = llm_model.device if llm_model else 'cpu'
             tts_model.to(device)
             tts_vocoder.to(device)
@@ -180,80 +187,141 @@
             tts_inputs = tts_processor(
                 text=generated_text,
                 return_tensors="pt",
-                max_length=550, # Set a max length to prevent excessively long audio
-                truncation=True # Enable truncation if text exceeds max_length
+                max_length=550,
+                truncation=True
             ).to(device)
 
             with torch.no_grad():
                 speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
 
-            # Create a temporary file to save the audio
             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                 audio_path = tmp_file.name
-                # Ensure audio data is on CPU before saving with soundfile
                 sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
             print(f"Audio saved to: {audio_path}")
 
         except Exception as e:
             print(f"Error generating audio: {e}")
-            audio_path = None # Return None if audio generation fails
+            audio_path = None
     else:
         print("TTS components not loaded. Skipping audio generation.")
 
-
     # --- 3. Update Chat History ---
-    # Append the latest generated response to the history with its role
     history.append({"role": "assistant", "content": generated_text})
 
     return history, audio_path
 
+
+# --- Transcribe Audio Function (NEW) ---
+@spaces.GPU # This function also needs GPU access for Whisper inference
+def transcribe_audio(audio_filepath):
+    """
+    Transcribes an audio file using the loaded Whisper model.
+    Handles audio files of varying lengths.
+    """
+    global whisper_processor, whisper_model
+
+    if whisper_processor is None or whisper_model is None:
+        load_models() # Attempt to load if not already loaded
+
+    if whisper_processor is None or whisper_model is None:
+        return "Error: Speech-to-Text model not loaded. Please check logs."
+
+    if audio_filepath is None:
+        return "No audio input provided for transcription."
+
+    print(f"Transcribing audio from: {audio_filepath}")
+    try:
+        # Load audio file and resample to 16kHz (Whisper's required sample rate)
+        audio, sample_rate = librosa.load(audio_filepath, sr=16000)
+
+        # Process audio input for the Whisper model
+        input_features = whisper_processor(
+            audio,
+            sampling_rate=sample_rate,
+            return_tensors="pt"
+        ).input_features.to(whisper_model.device)
+
+        # Generate transcription IDs
+        predicted_ids = whisper_model.generate(input_features)
+
+        # Decode the IDs to text
+        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+        print(f"Transcription: {transcription}")
+        return transcription
+
+    except Exception as e:
+        print(f"Error during transcription: {e}")
+        return f"Transcription failed: {e}"
+
+
 # --- Gradio Interface ---
 with gr.Blocks() as demo:
     gr.Markdown(
         """
-        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chat bot
-        Type your message below and chat with the model!
+        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chat bot with Voice Input & Output
+        Type your message or speak into the microphone to chat with the model.
+        The chatbot's response will be spoken, and your audio input can be transcribed!
         """
     )
 
-    # Set type='messages' for the chatbot to use OpenAI-style dictionaries
-    chatbot = gr.Chatbot(label="Conversation", type='messages')
-    with gr.Row():
-        text_input = gr.Textbox(
-            label="Your message",
-            placeholder="Type your message here...",
-            scale=4
+    with gr.Tab("Chat with Voice"):
+        chatbot = gr.Chatbot(label="Conversation", type='messages')
+        with gr.Row():
+            text_input = gr.Textbox(
+                label="Your message",
+                placeholder="Type your message here...",
+                scale=4
+            )
+            submit_button = gr.Button("Send", scale=1)
+
+        audio_output = gr.Audio(
+            label="Listen to Response",
+            autoplay=True,
+            interactive=False
         )
-        submit_button = gr.Button("Send", scale=1)
-
-    audio_output = gr.Audio(
-        label="Listen to Response",
-        autoplay=True, # Automatically play audio
-        interactive=False # Don't allow user to interact with this audio component
-    )
 
-    # Link the text input and button to the generation function
-    # Outputs now include both the chatbot history and the audio file path
-    submit_button.click(
-        fn=generate_response_and_audio,
-        inputs=[text_input, chatbot],
-        outputs=[chatbot, audio_output],
-        queue=True # Queue requests for better concurrency
-    )
-    text_input.submit( # Also trigger on Enter key
-        fn=generate_response_and_audio,
-        inputs=[text_input, chatbot],
-        outputs=[chatbot, audio_output],
-        queue=True
-    )
+        submit_button.click(
+            fn=generate_response_and_audio,
+            inputs=[text_input, chatbot],
+            outputs=[chatbot, audio_output],
+            queue=True
+        )
+        text_input.submit(
+            fn=generate_response_and_audio,
+            inputs=[text_input, chatbot],
+            outputs=[chatbot, audio_output],
+            queue=True
+        )
 
-    # Clear button
-    def clear_chat():
-        # Clear history, text input, and audio output
-        return [], "", None
-    clear_button = gr.Button("Clear Chat")
-    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input, audio_output])
+    with gr.Tab("Audio Transcription"):
+        stt_audio_input = gr.Audio(
+            type="filepath",
+            label="Upload Audio or Record from Microphone",
+            source="microphone", # Can be "microphone" or "upload" or ["microphone", "upload"]
+            format="wav" # Ensure consistent format
+        )
+        transcribe_button = gr.Button("Transcribe Audio")
+        transcribed_text_output = gr.Textbox(
+            label="Transcription",
+            placeholder="Transcription will appear here...",
+            interactive=False
+        )
+        transcribe_button.click(
+            fn=transcribe_audio,
+            inputs=[stt_audio_input],
+            outputs=[transcribed_text_output],
+            queue=True
+        )
 
+    # Clear button for the entire interface
+    def clear_all():
+        return [], "", None, None, "" # Clear chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output
+    clear_button = gr.Button("Clear All")
+    clear_button.click(
+        clear_all,
+        inputs=None,
+        outputs=[chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output]
+    )
 
 # Load all models when the app starts up
 load_models()
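A note on the new transcription input: recent Gradio releases (4.x and later) replaced the singular "source" argument of gr.Audio with a plural "sources" list, and gr.Chatbot(type='messages') already presumes a 4.x-era Gradio. If the Space runs on such a version, the stt_audio_input component added in this commit may need a small adjustment along the following lines. This is a sketch under that version assumption, not part of the commit itself:

    import gradio as gr

    # Hypothetical adjustment for Gradio 4.x+: the singular "source" keyword of
    # gr.Audio was replaced by a "sources" list; "type" and "format" are unchanged.
    stt_audio_input = gr.Audio(
        type="filepath",  # hand the recording to transcribe_audio as a file path
        label="Upload Audio or Record from Microphone",
        sources=["microphone", "upload"],  # plural keyword on Gradio 4.x and later
        format="wav"  # keep a consistent container format
    )

On Gradio 3.x the committed source="microphone" form works as written, so the change is only needed if the Space's Gradio dependency is newer.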