ajsbsd committed
Commit 41baf9f · verified · 1 Parent(s): 05b4227

Create app.py

Files changed (1)
  app.py  +330 -0
app.py ADDED
@@ -0,0 +1,330 @@
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    WhisperProcessor,                # New: for speech-to-text
    WhisperForConditionalGeneration  # New: for speech-to-text
)
from datasets import load_dataset  # To get a speaker embedding for TTS
import os
import spaces  # Hugging Face Spaces library for the GPU decorator
import tempfile  # For creating temporary audio files
import soundfile as sf  # To save audio files
import librosa  # New: for loading audio files for transcription

# --- Configuration for Language Model (LLM) ---
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"

# --- Configuration for Speech-to-Text (STT) ---
STT_MODEL_ID = "openai/whisper-tiny"  # A smaller Whisper model for faster inference

# --- Global variables for models and tokenizers/processors ---
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None  # New: global for the Whisper processor
whisper_model = None      # New: global for the Whisper model

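# Editor's note (an assumption, not part of this commit): for the imports above to resolve
# when this app runs as a Space, a requirements.txt along these lines is likely needed:
#
#   gradio
#   torch
#   transformers
#   accelerate    # needed for device_map="auto" in from_pretrained
#   datasets
#   soundfile
#   librosa
#
# The `spaces` package is usually provided by the Spaces runtime; add it explicitly if not.
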
# --- Load All Models Function ---
@spaces.GPU  # Decorate with @spaces.GPU to signal that this function needs GPU access
def load_models():
    """
    Loads the language model, tokenizer, TTS models, speaker embeddings,
    and STT (Whisper) models from the Hugging Face Hub.
    This function is called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
    global whisper_processor, whisper_model

    if (tokenizer is not None and llm_model is not None and tts_model is not None and
            whisper_processor is not None and whisper_model is not None):
        print("All models and tokenizers/processors already loaded.")
        return

    hf_token = os.environ.get("HF_TOKEN")

    # Load the language model (LLM)
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",
            token=hf_token
        )
        llm_model.eval()
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")

    # Load the TTS models
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        device = llm_model.device if llm_model else 'cpu'
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")

    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")

    # Load the STT (Whisper) model
    print(f"Loading STT (Whisper) processor and model from: {STT_MODEL_ID}")
    try:
        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)

        device = llm_model.device if llm_model else 'cpu'  # Use the same device as the LLM
        whisper_model.to(device)
        print(f"STT (Whisper) model loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading STT (Whisper) model or processor: {e}")
        whisper_processor = None
        whisper_model = None
        raise RuntimeError("Failed to load STT (Whisper) components. Check model ID and internet connection.")

# --- Generate Response and Audio Function ---
@spaces.GPU  # This function performs GPU-intensive inference
def generate_response_and_audio(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history (list of dicts with 'role' and 'content')
) -> tuple:  # Returns (updated_history, audio_file_path)
    """
    Generates a text response from the loaded LLM and then converts it to audio
    using the loaded TTS model.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    # Initialize all models if not already loaded
    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:  # Check LLM loading status
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # --- 1. Generate Text Response (LLM) ---
    messages = history
    messages.append({"role": "user", "content": message})

    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback prompt; skip the user message just appended, since it is added explicitly below
        input_text = ""
        for item in messages[:-1]:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)

    with torch.no_grad():
        output_ids = llm_model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id
        )

    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
        try:
            device = llm_model.device if llm_model else 'cpu'
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,
                truncation=True
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
            sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
            print(f"Audio saved to: {audio_path}")

        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None
    else:
        print("TTS components not loaded. Skipping audio generation.")

    # --- 3. Update Chat History ---
    history.append({"role": "assistant", "content": generated_text})

    return history, audio_path

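# Editor's sketch (an addition, not part of the original commit): generate_response_and_audio()
# writes each spoken reply to a new temporary .wav file that is never deleted, so a long-running
# Space can slowly fill its temp directory. A helper along these lines, called occasionally, is
# one hedged way to reclaim that space; the name and age threshold are illustrative only.
def cleanup_old_wav_files(max_age_seconds: int = 3600) -> None:
    """Delete .wav files in the system temp directory older than max_age_seconds."""
    import glob
    import time
    for path in glob.glob(os.path.join(tempfile.gettempdir(), "*.wav")):
        try:
            if time.time() - os.path.getmtime(path) > max_age_seconds:
                os.remove(path)
        except OSError:
            pass  # The file may already be gone or not be removable
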
# --- Transcribe Audio Function (NEW) ---
@spaces.GPU  # This function also needs GPU access for Whisper inference
def transcribe_audio(audio_filepath):
    """
    Transcribes an audio file using the loaded Whisper model.
    Handles audio files of varying lengths.
    """
    global whisper_processor, whisper_model

    if whisper_processor is None or whisper_model is None:
        load_models()  # Attempt to load the models if they are not already loaded

    if whisper_processor is None or whisper_model is None:
        return "Error: Speech-to-Text model not loaded. Please check logs."

    if audio_filepath is None:
        return "No audio input provided for transcription."

    print(f"Transcribing audio from: {audio_filepath}")
    try:
        # Load the audio file and resample to 16 kHz (Whisper's required sample rate)
        audio, sample_rate = librosa.load(audio_filepath, sr=16000)

        # Process the audio input for the Whisper model
        input_features = whisper_processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(whisper_model.device)

        # Generate transcription token IDs
        predicted_ids = whisper_model.generate(input_features)

        # Decode the IDs to text
        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription: {transcription}")
        return transcription

    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Transcription failed: {e}"

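# Editor's sketch (an addition, not part of the original commit): Whisper's feature extractor
# pads or truncates its input to 30-second windows, so a single generate() call in
# transcribe_audio() effectively caps recordings at roughly 30 seconds. If longer clips matter,
# a hedged workaround is to transcribe fixed-size chunks and join the pieces, as below; the
# helper name and chunk size are illustrative only.
def transcribe_long_audio(audio_filepath: str, chunk_seconds: int = 30, sr: int = 16000) -> str:
    """Transcribe a long audio file by splitting it into fixed-size chunks."""
    audio, _ = librosa.load(audio_filepath, sr=sr)
    pieces = []
    for start in range(0, len(audio), chunk_seconds * sr):
        chunk = audio[start:start + chunk_seconds * sr]
        input_features = whisper_processor(
            chunk, sampling_rate=sr, return_tensors="pt"
        ).input_features.to(whisper_model.device)
        predicted_ids = whisper_model.generate(input_features)
        pieces.append(
            whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
        )
    return " ".join(pieces)
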
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd Chatbot with Voice Input & Output
        Type your message or speak into the microphone to chat with the model.
        The chatbot's reply is spoken aloud, and your audio input can be transcribed.
        """
    )

    with gr.Tab("Chat with Voice"):
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        with gr.Row():
            text_input = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                scale=4
            )
            submit_button = gr.Button("Send", scale=1)

        audio_output = gr.Audio(
            label="Listen to Response",
            autoplay=True,
            interactive=False
        )

        submit_button.click(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )
        text_input.submit(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )

    with gr.Tab("Audio Transcription"):
        stt_audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio or Record from Microphone",
            # 'microphone=True' and 'source' are omitted: they raise a TypeError on older Gradio versions
            format="wav"  # Ensure a consistent format
        )
        transcribe_button = gr.Button("Transcribe Audio")
        transcribed_text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcription will appear here...",
            interactive=False
        )
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[stt_audio_input],
            outputs=[transcribed_text_output],
            queue=True
        )

    # Clear button for the entire interface
    def clear_all():
        # Clears chatbot, text_input, audio_output, stt_audio_input, and transcribed_text_output
        return [], "", None, None, ""

    clear_button = gr.Button("Clear All")
    clear_button.click(
        clear_all,
        inputs=None,
        outputs=[chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output]
    )

# Load all models when the app starts up
load_models()

# Launch the Gradio app
demo.queue().launch()