import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset  # To get a speaker embedding for TTS
import os
import spaces  # Import the spaces library for the GPU decorator
import tempfile  # For creating temporary audio files
import soundfile as sf  # To save audio files

# --- Configuration for Language Model (LLM) ---
# IMPORTANT: When deploying to Hugging Face Spaces, it's best to use the Hugging Face model ID
# rather than a local path ('.'), as the Space will fetch it from the Hub.
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"

# You might need to adjust TORCH_DTYPE based on your GPU and model support.
# torch.float16 (FP16) is common for inference, torch.bfloat16 for newer GPUs.
# For ZeroGPU/H200, bfloat16 is often preferred if the model supports it and the GPU allows.
TORCH_DTYPE = torch.bfloat16  # Use bfloat16 for optimal H200 performance

# Generation parameters for the LLM (can be adjusted for different response styles)
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"

# --- Global variables for models and tokenizers ---
tokenizer = None
llm_model = None  # Named llm_model to avoid conflict with tts_model
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None  # Global for TTS speaker embedding


# --- Load Models and Tokenizers Function ---
@spaces.GPU  # Decorate with @spaces.GPU to signal this function needs GPU access
def load_models():
    """
    Loads the language model, tokenizer, TTS models, and speaker embeddings
    from the Hugging Face Hub. This function is called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    if tokenizer is not None and llm_model is not None and tts_model is not None:
        print("All models and tokenizers already loaded.")
        return

    # When deploying to HF Spaces, you generally don't need an explicit HF_TOKEN
    # for public models, but it's good practice for private models or if
    # rate limits are hit.
    hf_token = os.environ.get("HF_TOKEN")  # Access HF_TOKEN from Space secrets if set

    # Load Language Model (LLM)
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",  # Automatically maps model to GPU if available, else CPU
            token=hf_token,  # Pass token if loading a private model
        )
        llm_model.eval()  # Set model to evaluation mode
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        print("Please ensure the LLM model ID is correct and you have an internet connection for the initial download, or that the local path is valid.")
        tokenizer = None
        llm_model = None
        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")

    # Load TTS models
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        # Load a speaker embedding (essential for SpeechT5 TTS).
        # Using a sample from a public dataset for demonstration.
        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        # Using a specific speaker embedding (you can experiment with different indices)
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        # Move TTS components to the same device as the LLM model
        device = llm_model.device if llm_model else 'cpu'
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        print("Please ensure the TTS model IDs are correct and you have an internet connection.")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")


# --- Generate Response and Audio Function ---
@spaces.GPU  # Decorate with @spaces.GPU as this function performs GPU-intensive inference
def generate_response_and_audio(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history (list of dictionaries with 'role' and 'content')
) -> tuple:  # Returns (updated_history, audio_file_path)
    """
    Generates a text response from the loaded LLM and then converts it to audio
    using the loaded TTS model.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    # Initialize all models if not already loaded
    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:  # Check LLM loading status
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # --- 1. Generate Text Response (LLM) ---
    # Format messages for the model's chat template
    messages = history  # Use history directly as it's already in the correct format
    messages.append({"role": "user", "content": message})  # Add current user message

    # Apply the chat template and tokenize
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback for models without explicit chat templates.
        # Note: `messages` already contains the current user message, so build the
        # prompt from it directly to avoid duplicating that message.
        input_text = ""
        for item in messages:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += "Assistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)

    # Generate response
    with torch.no_grad():  # Disable gradient calculations for inference
        output_ids = llm_model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,  # Important for generation to stop cleanly
        )

    # Decode the generated text, excluding the input prompt part
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
        try:
            # Ensure TTS components are on the correct device
            device = llm_model.device if llm_model else 'cpu'
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,  # Set a max length to prevent excessively long audio
                truncation=True  # Enable truncation if text exceeds max_length
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            # Create a temporary file to save the audio
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
            # Ensure audio data is on CPU before saving with soundfile (SpeechT5 outputs 16 kHz audio)
            sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
            print(f"Audio saved to: {audio_path}")
        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None  # Return None if audio generation fails
    else:
        print("TTS components not loaded. Skipping audio generation.")

    # --- 3. Update Chat History ---
    # Append the latest generated response to the history with its role
    # (the user message was already appended above via `messages`).
    history.append({"role": "assistant", "content": generated_text})

    return history, audio_path


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chat bot
        Type your message below and chat with the model!
        """
    )
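    # With type='messages', each history entry is a plain dictionary such as
    # {"role": "user", "content": "Hello"} or {"role": "assistant", "content": "Hi!"},
    # which matches the format generate_response_and_audio reads and appends to.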
""" ) # Set type='messages' for the chatbot to use OpenAI-style dictionaries chatbot = gr.Chatbot(label="Conversation", type='messages') with gr.Row(): text_input = gr.Textbox( label="Your message", placeholder="Type your message here...", scale=4 ) submit_button = gr.Button("Send", scale=1) audio_output = gr.Audio( label="Listen to Response", autoplay=True, # Automatically play audio interactive=False # Don't allow user to interact with this audio component ) # Link the text input and button to the generation function # Outputs now include both the chatbot history and the audio file path submit_button.click( fn=generate_response_and_audio, inputs=[text_input, chatbot], outputs=[chatbot, audio_output], queue=True # Queue requests for better concurrency ) text_input.submit( # Also trigger on Enter key fn=generate_response_and_audio, inputs=[text_input, chatbot], outputs=[chatbot, audio_output], queue=True ) # Clear button def clear_chat(): # Clear history, text input, and audio output return [], "", None clear_button = gr.Button("Clear Chat") clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input, audio_output]) # Load all models when the app starts up load_models() # Launch the Gradio app demo.queue().launch()