ajsbsd committed
Commit 266db90 · verified · 1 Parent(s): add83be

Delete app.py

Files changed (1)
  1. app.py +0 -366
app.py DELETED
@@ -1,366 +0,0 @@
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    WhisperProcessor,  # For Speech-to-Text
    WhisperForConditionalGeneration  # For Speech-to-Text
)
from datasets import load_dataset  # To get a speaker embedding for TTS
import os
import spaces  # Import the spaces library for the GPU decorator
import tempfile  # For creating temporary audio files
import soundfile as sf  # To save audio files
import librosa  # For loading audio files for transcription

# --- Configuration for Language Model (LLM) ---
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"

# --- Configuration for Speech-to-Text (STT) ---
STT_MODEL_ID = "openai/whisper-small"  # Changed from 'openai/whisper-tiny' for better long-audio transcription

# --- Global variables for models and tokenizers/processors ---
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None
whisper_model = None

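# The chat functions below exchange history in Gradio's "messages" format,
# i.e. a list of {"role": ..., "content": ...} dicts, matching the
# gr.Chatbot(type="messages") component defined in the UI section.
# A purely illustrative (hypothetical) example:
#   history = [
#       {"role": "user", "content": "Hello!"},
#       {"role": "assistant", "content": "Hi! How can I help?"},
#   ]
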
# --- Load All Models Function ---
@spaces.GPU  # Decorate with @spaces.GPU to signal this function needs GPU access
def load_models():
    """
    Loads the language model, tokenizer, TTS models, speaker embeddings,
    and STT (Whisper) models from the Hugging Face Hub.
    This function will be called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
    global whisper_processor, whisper_model

    if (tokenizer is not None and llm_model is not None and tts_model is not None and
            whisper_processor is not None and whisper_model is not None):
        print("All models and tokenizers/processors already loaded.")
        return

    hf_token = os.environ.get("HF_TOKEN")

    # Load Language Model (LLM)
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",
            token=hf_token
        )
        llm_model.eval()
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")

    # Load TTS models
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        device = llm_model.device if llm_model else "cpu"
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")

    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")

    # Load STT (Whisper) model
    print(f"Loading STT (Whisper) processor and model from: {STT_MODEL_ID}")
    try:
        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)

        device = llm_model.device if llm_model else "cpu"  # Use the same device as the LLM
        whisper_model.to(device)
        print(f"STT (Whisper) model loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading STT (Whisper) model or processor: {e}")
        whisper_processor = None
        whisper_model = None
        raise RuntimeError("Failed to load STT (Whisper) components. Check model ID and internet connection.")

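# load_models() reads an optional access token from the HF_TOKEN environment
# variable (typically a Space secret, or exported locally). The model IDs used
# here are expected to be publicly downloadable, so the token may be left
# unset; the value below is purely hypothetical:
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
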
# --- Generate Response and Audio Function ---
@spaces.GPU
def generate_response_and_audio(message: str, history: list) -> tuple:
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # Initialize generated_text early so the history update below always has a value
    generated_text = ""

    # --- 1. Generate Text Response (LLM) ---
    messages = history.copy()
    messages.append({"role": "user", "content": message})

    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fall back to a plain-text prompt if the tokenizer has no chat template
        input_text = ""
        for item in history:
            input_text += f"{item['role'].capitalize()}: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    try:
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
        with torch.no_grad():
            output_ids = llm_model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=DO_SAMPLE,
                temperature=TEMPERATURE,
                top_k=TOP_K,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, not the echoed prompt
        generated_token_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
        generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    except Exception as e:
        print(f"Error during LLM generation: {e}")
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "I encountered an error while generating a response."})
        return history, None

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    # Check each component explicitly; calling all() on a list containing a tensor would raise
    if all(x is not None for x in [tts_processor, tts_model, tts_vocoder, speaker_embeddings]):
        try:
            device = llm_model.device if llm_model else "cpu"
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,
                truncation=True
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
            sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
            print(f"Audio saved to: {audio_path}")

        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None
    else:
        print("TTS components not fully loaded. Skipping audio generation.")

    # --- 3. Update Chat History ---
    # Record both the user turn and the assistant reply so the chat displays the full exchange
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated_text})
    return history, audio_path

# --- Transcribe Audio Function (NEW) ---
@spaces.GPU  # This function also needs GPU access for Whisper inference
def transcribe_audio(audio_filepath):
    """
    Transcribes an audio file using the loaded Whisper model.
    Handles audio files of varying lengths.
    """
    global whisper_processor, whisper_model

    if whisper_processor is None or whisper_model is None:
        load_models()  # Attempt to load if not already loaded

    if whisper_processor is None or whisper_model is None:
        return "Error: Speech-to-Text model not loaded. Please check logs."

    if audio_filepath is None:
        return "No audio input provided for transcription."

    print(f"Transcribing audio from: {audio_filepath}")
    try:
        # Load the audio file and resample to 16 kHz (Whisper's required sample rate)
        audio, sample_rate = librosa.load(audio_filepath, sr=16000)

        # Process the audio input for the Whisper model.
        # Whisper's `generate` method, especially with larger models, is designed
        # to handle variable-length inputs by internally managing context.
        input_features = whisper_processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(whisper_model.device)

        # Generate transcription token IDs
        predicted_ids = whisper_model.generate(input_features)

        # Decode the IDs to text
        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription: {transcription}")
        return transcription

    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Transcription failed: {e}"

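# A minimal sketch of calling the two functions above directly, outside the
# Gradio UI (assumes the models are loaded; "sample.wav" is a hypothetical file):
#   load_models()
#   history, audio_path = generate_response_and_audio("Hello!", [])
#   print(history[-1]["content"], audio_path)
#   print(transcribe_audio("sample.wav"))
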
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chatbot with Voice Input & Output
        Type your message or speak into the microphone to chat with the model.
        The chatbot's response is spoken aloud, and your audio input can be transcribed.
        """
    )

    with gr.Tab("Chat with Voice"):
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        with gr.Row():
            text_input = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                scale=4
            )
            submit_button = gr.Button("Send", scale=1)

        audio_output = gr.Audio(
            label="Listen to Response",
            autoplay=True,
            interactive=False
        )

        submit_button.click(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )
        text_input.submit(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )

    with gr.Tab("Audio Transcription"):
        stt_audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio or Record from Microphone",
            # The 'microphone' and 'source' arguments are omitted because they raise
            # a TypeError on older Gradio versions; on such versions this component
            # supports file uploads only.
            format="wav"  # Ensure a consistent format
        )
        transcribe_button = gr.Button("Transcribe Audio")
        transcribed_text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcription will appear here...",
            interactive=False
        )
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[stt_audio_input],
            outputs=[transcribed_text_output],
            queue=True
        )

    # Clear button for the entire interface
    def clear_all():
        # Clears chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output
        return [], "", None, None, ""

    clear_button = gr.Button("Clear All")
    clear_button.click(
        clear_all,
        inputs=None,
        outputs=[chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output]
    )

# Load all models when the app starts up
load_models()

# Launch the Gradio app
demo.queue().launch()