File size: 11,276 Bytes
5995f30
 
28ccf5e
 
5995f30
28ccf5e
 
 
5995f30
28ccf5e
 
 
61ee61a
5995f30
 
 
28ccf5e
 
5995f30
28ccf5e
5995f30
 
 
 
 
 
28ccf5e
 
 
 
5995f30
 
28ccf5e
 
 
 
 
5995f30
 
28ccf5e
 
5995f30
28ccf5e
 
5995f30
28ccf5e
5995f30
28ccf5e
 
5995f30
 
28ccf5e
 
 
 
 
 
 
5995f30
28ccf5e
5995f30
 
 
 
28ccf5e
 
5995f30
 
28ccf5e
 
5995f30
28ccf5e
 
5995f30
28ccf5e
 
5995f30
28ccf5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5995f30
28ccf5e
 
 
 
 
 
5995f30
28ccf5e
 
 
 
 
 
 
 
 
 
 
 
 
5995f30
 
28ccf5e
5995f30
28ccf5e
 
5995f30
28ccf5e
5995f30
28ccf5e
 
 
5995f30
28ccf5e
5995f30
28ccf5e
 
5995f30
28ccf5e
 
5995f30
 
 
 
 
 
 
 
28ccf5e
5995f30
 
 
 
 
 
 
 
28ccf5e
5995f30
 
 
28ccf5e
5995f30
 
 
 
 
 
 
 
 
 
 
 
 
28ccf5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5995f30
 
 
28ccf5e
5995f30
 
 
 
 
7315d1a
 
5995f30
 
 
 
 
 
 
 
 
 
 
 
 
28ccf5e
 
 
 
 
 
5995f30
28ccf5e
5995f30
28ccf5e
 
 
5995f30
 
 
28ccf5e
5995f30
28ccf5e
5995f30
 
 
 
 
28ccf5e
 
5995f30
28ccf5e
5995f30
 
28ccf5e
 
5995f30
 
28ccf5e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset # To get a speaker embedding for TTS
import os
import spaces # Import the spaces library for GPU decorator
import tempfile # For creating temporary audio files
import soundfile as sf # To save audio files

# --- Configuration for Language Model (LLM) ---
# Hub ID of the chat model. On Hugging Face Spaces a Hub ID is preferable to a
# local path ('.') because the Space downloads the weights from the Hub itself.
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"

# Precision used when loading the LLM weights. bfloat16 is chosen for newer
# GPUs such as the H200 behind ZeroGPU; use torch.float16 where bfloat16 is
# not supported by the GPU or the model.
TORCH_DTYPE = torch.bfloat16

# Sampling settings passed straight to llm_model.generate(); adjust them to
# change reply length and randomness.
MAX_NEW_TOKENS = 512  # cap on tokens generated per reply
DO_SAMPLE = True      # sample from the distribution rather than decode greedily
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"        # SpeechT5 text-to-speech model
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"  # vocoder passed to generate_speech()

# --- Global model handles, populated by load_models() ---
# Everything starts out unset. The LLM is called `llm_model` so it is not
# confused with `tts_model`.
tokenizer = llm_model = None
tts_processor = tts_model = tts_vocoder = None
speaker_embeddings = None  # x-vector selecting the TTS voice

# --- Load Models and Tokenizers Function ---
@spaces.GPU # Decorate with @spaces.GPU to signal this function needs GPU access
def load_models():
    """
    Load the LLM, its tokenizer, the SpeechT5 TTS components and a speaker embedding.

    Only component groups whose globals are still ``None`` are loaded. A call
    made after a partial failure (for example, the LLM loaded but TTS did not)
    therefore retries just the missing part instead of reloading everything.
    It is called once at app startup and again from
    ``generate_response_and_audio`` whenever a model global is unset.

    Side effects:
        Assigns the module globals ``tokenizer``, ``llm_model``, ``tts_processor``,
        ``tts_model``, ``tts_vocoder`` and ``speaker_embeddings``. The TTS
        model, vocoder and speaker embedding are moved to ``llm_model.device``.

    Raises:
        RuntimeError: if the LLM/tokenizer or any TTS component fails to load.
            The underlying exception is chained as ``__cause__``, and the
            globals of the group that failed are reset to ``None``.
    """
    # NOTE(review): on ZeroGPU, @spaces.GPU functions may execute outside the
    # main process; confirm that the globals assigned here persist between calls.
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    llm_loaded = tokenizer is not None and llm_model is not None
    tts_loaded = all(
        component is not None
        for component in (tts_processor, tts_model, tts_vocoder, speaker_embeddings)
    )
    if llm_loaded and tts_loaded:
        print("All models and tokenizers already loaded.")
        return

    # HF_TOKEN (e.g. a Space secret) is optional for public models; it is
    # forwarded to every Hub download so private models / rate limits work.
    hf_token = os.environ.get("HF_TOKEN")

    # Load Language Model (LLM)
    if not llm_loaded:
        print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
            if tokenizer.pad_token is None:
                # Ensure a pad token exists; reuse EOS when the tokenizer defines none.
                tokenizer.pad_token = tokenizer.eos_token
                print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

            print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
            llm_model = AutoModelForCausalLM.from_pretrained(
                HUGGINGFACE_MODEL_ID,
                torch_dtype=TORCH_DTYPE,
                device_map="auto", # Automatically maps model to GPU if available, else CPU
                token=hf_token, # Pass token if loading private model
            )
            llm_model.eval() # Set model to evaluation mode
            print("LLM model loaded successfully.")
        except Exception as e:
            print(f"Error loading LLM model or tokenizer: {e}")
            print("Please ensure the LLM model ID is correct and you have an internet connection for initial download, or the local path is valid.")
            tokenizer = None
            llm_model = None
            raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.") from e

    # Load TTS models
    if not tts_loaded:
        print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
        try:
            tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
            tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
            tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

            # SpeechT5 needs a speaker embedding; take one x-vector from a public dataset.
            print("Loading speaker embeddings for TTS...")
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
            # Row 7306 fixes the voice; other indices select other speakers.
            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

            # Keep TTS on the same device as the LLM.
            device = llm_model.device if llm_model is not None else 'cpu'
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)
            print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")

        except Exception as e:
            print(f"Error loading TTS models or speaker embeddings: {e}")
            print("Please ensure TTS model IDs are correct and you have an internet connection.")
            tts_processor = None
            tts_model = None
            tts_vocoder = None
            speaker_embeddings = None
            raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.") from e


# --- Generate Response and Audio Function ---
@spaces.GPU # Decorate with @spaces.GPU as this function performs GPU-intensive inference
def generate_response_and_audio(
    message: str, # Current user message
    history: list # Gradio Chatbot history format (list of dictionaries with 'role' and 'content')
) -> tuple: # Returns (updated_history, audio_file_path)
    """
    Generate an LLM reply to ``message`` and synthesize it into a WAV file.

    Args:
        message: The new user message.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts (the
            Gradio ``type='messages'`` Chatbot format). The list is not
            modified; a new list is returned.

    Returns:
        ``(updated_history, audio_path)``. ``updated_history`` is ``history``
        followed by the user turn and the assistant turn. If the LLM cannot be
        loaded, the assistant turn holds an error message. ``audio_path`` is a
        temporary WAV file written at 16 kHz. It is ``None`` when TTS is
        unavailable, the reply is empty, or synthesis fails.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    # Copy so the caller's list is never mutated; tolerate a None history.
    prior_turns = list(history or [])
    user_turn = {"role": "user", "content": message}

    # Lazily (re)load anything missing. load_models() raises RuntimeError on
    # failure. Catch it so that a TTS failure only disables audio, and an LLM
    # failure is reported in the chat instead of crashing the request.
    if tokenizer is None or llm_model is None or tts_model is None:
        try:
            load_models()
        except RuntimeError as e:
            print(f"Model loading failed: {e}")

    if tokenizer is None or llm_model is None: # Check LLM loading status
        error_turn = {"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."}
        return prior_turns + [user_turn, error_turn], None

    # --- 1. Generate Text Response (LLM) ---
    messages = prior_turns + [user_turn]

    # Apply the chat template (as text; tokenized below)
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback for models without a chat template. `prior_turns` does not
        # contain the current message, so it is appended exactly once below.
        input_text = ""
        for item in prior_turns:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    # Pass the attention mask along with input_ids so generate() does not
    # have to infer it from the pad token.
    model_inputs = tokenizer(input_text, return_tensors="pt").to(llm_model.device)
    prompt_length = model_inputs["input_ids"].shape[-1]

    with torch.no_grad(): # Disable gradient calculations for inference
        output_ids = llm_model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_token_ids = output_ids[0][prompt_length:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    tts_available = all(
        component is not None
        for component in (tts_processor, tts_model, tts_vocoder, speaker_embeddings)
    )
    if not tts_available:
        print("TTS components not loaded. Skipping audio generation.")
    elif not generated_text:
        print("Empty response. Skipping audio generation.")
    else:
        try:
            # Ensure TTS components are on the same device as the LLM
            device = llm_model.device
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550, # Set a max length to prevent excessively long audio
                truncation=True, # Enable truncation if text exceeds max_length
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            # Reserve a unique path, then write after the handle is closed so
            # soundfile can reopen the file on every platform.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
            # soundfile needs the audio on the CPU as a NumPy array.
            sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
            print(f"Audio saved to: {audio_path}")

        except Exception as e:
            print(f"Error generating audio: {e}")
            # Don't leave an empty/partial temp file behind.
            if audio_path is not None and os.path.exists(audio_path):
                os.remove(audio_path)
            audio_path = None # Return None if audio generation fails

    # --- 3. Update Chat History ---
    updated_history = messages + [{"role": "assistant", "content": generated_text}]
    return updated_history, audio_path

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chat bot
        Type your message below and chat with the model!
        """
    )

    # type='messages' makes the Chatbot exchange OpenAI-style
    # {"role": ..., "content": ...} dicts, the format generate_response_and_audio uses.
    chatbot = gr.Chatbot(label="Conversation", type='messages')
    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4
        )
        submit_button = gr.Button("Send", scale=1)

    audio_output = gr.Audio(
        label="Listen to Response",
        autoplay=True, # Automatically play audio
        interactive=False # Output only; don't allow user to interact with this audio component
    )

    def clear_textbox():
        """Return an empty string to reset the message box after a reply."""
        return ""

    # The Send button and the Enter key share one pipeline. First generate the
    # reply (chat history + audio path). Then, only on success, empty the
    # textbox so the same message is not resent by accident. On failure the
    # text is kept so the user can retry.
    for trigger in (submit_button.click, text_input.submit):
        trigger(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True # Queue requests for better concurrency
        ).success(clear_textbox, inputs=None, outputs=text_input)

    # Clear button
    def clear_chat():
        # Clear history, text input, and audio output
        return [], "", None
    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input, audio_output])


# Load every model up front, so the first chat request doesn't trigger the
# lazy load inside generate_response_and_audio.
load_models()

# Enable the request queue, then start the Gradio server.
demo.queue()
demo.launch()