ajsbsd committed
Commit 05b4227 · verified · 1 Parent(s): 39cc919

Delete app.py

Files changed (1):
  app.py +0 -330
app.py DELETED
@@ -1,330 +0,0 @@
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    WhisperProcessor,                # For speech-to-text
    WhisperForConditionalGeneration  # For speech-to-text
)
from datasets import load_dataset  # To get a speaker embedding for TTS
import os
import spaces  # Hugging Face Spaces library, provides the @spaces.GPU decorator
import tempfile  # For creating temporary audio files
import soundfile as sf  # To save audio files
import librosa  # For loading audio files for transcription

# --- Configuration for Language Model (LLM) ---
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"

# --- Configuration for Speech-to-Text (STT) ---
STT_MODEL_ID = "openai/whisper-tiny"  # A small Whisper model for faster inference

# --- Global variables for models and tokenizers/processors ---
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None  # Whisper processor (STT)
whisper_model = None      # Whisper model (STT)

# --- Load All Models Function ---
@spaces.GPU  # Signals to Spaces that this function needs GPU access
def load_models():
    """
    Loads the language model, tokenizer, TTS models, speaker embeddings,
    and STT (Whisper) models from the Hugging Face Hub.
    This function is called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
    global whisper_processor, whisper_model

    if (tokenizer is not None and llm_model is not None and tts_model is not None and
            whisper_processor is not None and whisper_model is not None):
        print("All models and tokenizers/processors already loaded.")
        return

    hf_token = os.environ.get("HF_TOKEN")

    # Load Language Model (LLM)
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")
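        # (Many causal-LM tokenizers ship without a pad token; reusing the EOS
        # token avoids padding errors during tokenization and generation.)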

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",
            token=hf_token
        )
        llm_model.eval()
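        # eval() disables dropout and other training-time behavior for inference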
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")

    # Load TTS models
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
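        # (Row 7306 is one arbitrary x-vector from the dataset; picking a
        # different row changes the synthesized voice.)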

        device = llm_model.device if llm_model else 'cpu'
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")

    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")

    # Load STT (Whisper) model
    print(f"Loading STT (Whisper) processor and model from: {STT_MODEL_ID}")
    try:
        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)

        device = llm_model.device if llm_model else 'cpu'  # Use the same device as the LLM
        whisper_model.to(device)
        print(f"STT (Whisper) model loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading STT (Whisper) model or processor: {e}")
        whisper_processor = None
        whisper_model = None
        raise RuntimeError("Failed to load STT (Whisper) components. Check model ID and internet connection.")


# --- Generate Response and Audio Function ---
@spaces.GPU  # This function performs GPU-intensive inference
def generate_response_and_audio(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history (list of dicts with 'role' and 'content')
) -> tuple:  # Returns (updated_history, audio_file_path)
    """
    Generates a text response from the loaded LLM and then converts it to audio
    using the loaded TTS model.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    # Initialize all models if not already loaded
    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:  # Check LLM loading status
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # --- 1. Generate Text Response (LLM) ---
    messages = history
    messages.append({"role": "user", "content": message})
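    # Note: `messages` is the same list object as `history`, so the user turn
    # above is added to the shared chat history in place.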

    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fall back to a plain-text prompt; `messages` already includes the new user turn
        input_text = ""
        for item in messages:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += "Assistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)

    with torch.no_grad():
        output_ids = llm_model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id
        )
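
    # generate() returns the prompt followed by the newly generated tokens,
    # so the slice below keeps only the model's reply.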
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
        try:
            device = llm_model.device if llm_model else 'cpu'
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,
                truncation=True
            ).to(device)
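            # (SpeechT5's text encoder has a fixed maximum input length;
            # max_length=550 with truncation=True drops text beyond it, so very
            # long replies are silently cut short in the audio.)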

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
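                # SpeechT5 + HiFi-GAN produce 16 kHz audio; delete=False keeps
                # the temp file on disk so Gradio can serve it after this block.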
                print(f"Audio saved to: {audio_path}")

        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None
    else:
        print("TTS components not loaded. Skipping audio generation.")

    # --- 3. Update Chat History ---
    history.append({"role": "assistant", "content": generated_text})

    return history, audio_path


# --- Transcribe Audio Function ---
@spaces.GPU  # This function also needs GPU access for Whisper inference
def transcribe_audio(audio_filepath):
    """
    Transcribes an audio file using the loaded Whisper model.
    Handles audio files of varying lengths.
    """
    global whisper_processor, whisper_model

    if whisper_processor is None or whisper_model is None:
        load_models()  # Attempt to load if not already loaded

    if whisper_processor is None or whisper_model is None:
        return "Error: Speech-to-Text model not loaded. Please check logs."

    if audio_filepath is None:
        return "No audio input provided for transcription."

    print(f"Transcribing audio from: {audio_filepath}")
    try:
        # Load the audio file and resample to 16 kHz (Whisper's required sample rate)
        audio, sample_rate = librosa.load(audio_filepath, sr=16000)

        # Process the audio input for the Whisper model
        input_features = whisper_processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(whisper_model.device)
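        # (The processor converts raw audio into log-mel spectrogram features,
        # padded or trimmed to Whisper's fixed 30-second window.)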

        # Generate transcription token IDs
        predicted_ids = whisper_model.generate(input_features)
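        # (With no forced decoder IDs, Whisper auto-detects the language and
        # defaults to the transcription task rather than translation.)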

        # Decode the token IDs to text
        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription: {transcription}")
        return transcription

    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Transcription failed: {e}"


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chatbot with Voice Input & Output
        Type your message or speak into the microphone to chat with the model.
        The chatbot's response will be spoken, and your audio input can be transcribed!
        """
    )

    with gr.Tab("Chat with Voice"):
        chatbot = gr.Chatbot(label="Conversation", type="messages")
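        # (type="messages" makes the Chatbot expect the same list of
        # {"role", "content"} dicts that generate_response_and_audio returns.)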
        with gr.Row():
            text_input = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                scale=4
            )
            submit_button = gr.Button("Send", scale=1)

        audio_output = gr.Audio(
            label="Listen to Response",
            autoplay=True,
            interactive=False
        )

        submit_button.click(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )
        text_input.submit(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )

    with gr.Tab("Audio Transcription"):
        stt_audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio or Record from Microphone",
            sources=["upload", "microphone"],  # Gradio 4.x API (replaces source="microphone")
            format="wav"  # Ensure a consistent format
        )
        transcribe_button = gr.Button("Transcribe Audio")
        transcribed_text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcription will appear here...",
            interactive=False
        )
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[stt_audio_input],
            outputs=[transcribed_text_output],
            queue=True
        )

    # Clear button for the entire interface
    def clear_all():
        # Clears chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output
        return [], "", None, None, ""

    clear_button = gr.Button("Clear All")
    clear_button.click(
        clear_all,
        inputs=None,
        outputs=[chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output]
    )

# Load all models when the app starts up
load_models()

# Launch the Gradio app
demo.queue().launch()