import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    WhisperProcessor, # For Speech-to-Text
    WhisperForConditionalGeneration # For Speech-to-Text
)
from datasets import load_dataset # To get a speaker embedding for TTS
import os
import spaces # Import the spaces library for GPU decorator
import tempfile # For creating temporary audio files
import soundfile as sf # To save audio files
import librosa # For loading audio files for transcription

# --- Configuration for Language Model (LLM) ---
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
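# SpeechT5 predicts mel spectrograms; the HiFi-GAN vocoder converts them to 16 kHz waveforms.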

# --- Configuration for Speech-to-Text (STT) ---
STT_MODEL_ID = "openai/whisper-small"  # 'small' is noticeably more accurate than 'tiny' at a modest speed cost
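# Note: Whisper's feature extractor works on 30-second windows; see transcribe_audio below
# for how this affects long recordings.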

# --- Global variables for models and tokenizers/processors ---
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None
whisper_model = None

# --- Load All Models Function ---
@spaces.GPU # Decorate with @spaces.GPU to signal this function needs GPU access
def load_models():
    """
    Loads the language model, tokenizer, TTS models, speaker embeddings,
    and STT (Whisper) models from Hugging Face Hub.
    This function will be called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
    global whisper_processor, whisper_model

    if (tokenizer is not None and llm_model is not None and tts_model is not None and
        whisper_processor is not None and whisper_model is not None):
        print("All models and tokenizers/processors already loaded.")
        return

    hf_token = os.environ.get("HF_TOKEN")

    # Load Language Model (LLM)
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",
            token=hf_token
        )
        llm_model.eval()
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")

    # Load TTS models
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
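        # Index 7306 is the x-vector used in the standard SpeechT5 examples (a US English
        # female speaker); any other row of the dataset can be substituted for a different voice.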

        device = llm_model.device if llm_model else 'cpu'
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")

    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")

    # Load STT (Whisper) model
    print(f"Loading STT (Whisper) processor and model from: {STT_MODEL_ID}")
    try:
        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)

        device = llm_model.device if llm_model else 'cpu' # Use the same device as LLM
        whisper_model.to(device)
        print(f"STT (Whisper) model loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading STT (Whisper) model or processor: {e}")
        whisper_processor = None
        whisper_model = None
        raise RuntimeError("Failed to load STT (Whisper) components. Check model ID and internet connection.")


# --- Generate Response and Audio Function ---
@spaces.GPU
def generate_response_and_audio(message: str, history: list) -> tuple:
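    """
    Generates a text reply with the LLM, synthesizes it to speech with SpeechT5,
    and returns the updated chat history plus a path to the audio file (or None
    if TTS failed or is unavailable).
    """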
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # Initialize generated_text early
    generated_text = ""

    # --- 1. Generate Text Response (LLM) ---
    messages = history.copy()
    messages.append({"role": "user", "content": message})

    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        input_text = ""
        for item in history:
            input_text += f"{item['role'].capitalize()}: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    try:
        input_ids = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
        with torch.no_grad():
            output_ids = llm_model.generate(
                input_ids["input_ids"],
                attention_mask=input_ids["attention_mask"],
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=DO_SAMPLE,
                temperature=TEMPERATURE,
                top_k=TOP_K,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id
            )

        generated_token_ids = output_ids[0][input_ids["input_ids"].shape[-1]:]
        generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    except Exception as e:
        print(f"Error during LLM generation: {e}")
        history.append({"role": "assistant", "content": "I encountered an error while generating a response."})
        return history, None

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    if all(x is not None for x in (tts_processor, tts_model, tts_vocoder, speaker_embeddings)):
        try:
            device = llm_model.device if llm_model else 'cpu'
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

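            # SpeechT5 accepts a limited number of text tokens, so long replies are
            # truncated at max_length below; splitting a reply into sentences and
            # synthesizing each chunk would be one way to avoid losing text.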
            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,
                truncation=True
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

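            # delete=False leaves the file on disk so Gradio can serve it; a
            # long-running app may want to clean up old temporary files periodically.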
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)

            print(f"Audio saved to: {audio_path}")

        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None
    else:
        print("TTS components not fully loaded. Skipping audio generation.")

    # --- 3. Update Chat History ---
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated_text})
    return history, audio_path


# --- Transcribe Audio Function ---
@spaces.GPU # This function also needs GPU access for Whisper inference
def transcribe_audio(audio_filepath):
    """
    Transcribes an audio file using the loaded Whisper model.
    Handles audio files of varying lengths.
    """
    global whisper_processor, whisper_model

    if whisper_processor is None or whisper_model is None:
        load_models() # Attempt to load if not already loaded

    if whisper_processor is None or whisper_model is None:
        return "Error: Speech-to-Text model not loaded. Please check logs."

    if audio_filepath is None:
        return "No audio input provided for transcription."

    print(f"Transcribing audio from: {audio_filepath}")
    try:
        # Load audio file and resample to 16kHz (Whisper's required sample rate)
        audio, sample_rate = librosa.load(audio_filepath, sr=16000)

        # Process audio input for the Whisper model.
        # Note: the Whisper feature extractor pads/truncates audio to a single
        # 30-second window by default, so only roughly the first 30 seconds of a
        # long recording are transcribed here.
        input_features = whisper_processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(whisper_model.device)

        # Generate transcription IDs
        predicted_ids = whisper_model.generate(input_features)

        # Decode the IDs to text
        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription: {transcription}")
        return transcription

    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Transcription failed: {e}"


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd Chatbot with Voice Input & Output
        Type your message or speak into the microphone to chat with the model.
        The chatbot's response will be spoken, and your audio input can be transcribed!
        """
    )

    with gr.Tab("Chat with Voice"):
        chatbot = gr.Chatbot(label="Conversation", type='messages')
        with gr.Row():
            text_input = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                scale=4
            )
            submit_button = gr.Button("Send", scale=1)

        audio_output = gr.Audio(
            label="Listen to Response",
            autoplay=True,
            interactive=False
        )

        submit_button.click(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )
        text_input.submit(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True
        )

    with gr.Tab("Audio Transcription"):
        stt_audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio or Record from Microphone",
            # The 'source'/'microphone' kwargs vary across Gradio versions; with the
            # defaults this component accepts file uploads (and microphone input on
            # recent Gradio versions).
            format="wav"  # Ensure consistent format
        )
        transcribe_button = gr.Button("Transcribe Audio")
        transcribed_text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcription will appear here...",
            interactive=False
        )
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[stt_audio_input],
            outputs=[transcribed_text_output],
            queue=True
        )

    # Clear button for the entire interface
    def clear_all():
        return [], "", None, None, "" # Clear chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output
    clear_button = gr.Button("Clear All")
    clear_button.click(
        clear_all,
        inputs=None,
        outputs=[chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output]
    )

# Load all models when the app starts up
load_models()

# Launch the Gradio app
demo.queue().launch()