import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset # To get a speaker embedding for TTS
import os
import spaces # Import the spaces library for GPU decorator
import tempfile # For creating temporary audio files
import soundfile as sf # To save audio files
# --- Configuration for Language Model (LLM) ---
# IMPORTANT: When deploying to Hugging Face Spaces, it's best to use the Hugging Face model ID
# rather than a local path ('.'), as the Space will fetch it from the Hub.
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
# You might need to adjust TORCH_DTYPE based on your GPU and model support
# torch.float16 (FP16) is common for inference, torch.bfloat16 for newer GPUs
# For ZeroGPU/H200, bfloat16 is often preferred if the model supports it and GPU allows.
TORCH_DTYPE = torch.bfloat16 # Use bfloat16 for optimal H200 performance
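# Note: bfloat16 requires Ampere-or-newer GPUs. A minimal, optional fallback (an assumption,
# not part of the original configuration) would switch to float16 on older hardware:
if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
    TORCH_DTYPE = torch.float16  # fall back where bfloat16 is unsupported
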
# Generation parameters for the LLM (can be adjusted for different response styles)
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95
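# For more deterministic replies, sampling can be disabled (illustrative alternative, not
# used below); with greedy decoding the temperature/top-k/top-p values above are ignored:
# DO_SAMPLE = False
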
# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
# --- Global variables for models and tokenizers ---
tokenizer = None
llm_model = None # Renamed to avoid conflict with tts_model
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None # Global for TTS speaker embedding

# --- Load Models and Tokenizers Function ---
@spaces.GPU  # Decorate with @spaces.GPU to signal this function needs GPU access
def load_models():
    """
    Loads the language model, tokenizer, TTS models, and speaker embeddings
    from the Hugging Face Hub. This function is called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    if tokenizer is not None and llm_model is not None and tts_model is not None:
        print("All models and tokenizers already loaded.")
        return

    # When deploying to HF Spaces, you generally don't need an explicit HF_TOKEN
    # for public models, but it's good practice for private models or if
    # rate limits are hit.
    hf_token = os.environ.get("HF_TOKEN")  # Access HF_TOKEN from Space secrets if set

    # Load Language Model (LLM)
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",  # Automatically maps model to GPU if available, else CPU
            token=hf_token      # Pass token if loading a private model
        )
        llm_model.eval()  # Set model to evaluation mode
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        print("Please ensure the LLM model ID is correct and you have an internet connection for the initial download.")
        tokenizer = None
        llm_model = None
        raise RuntimeError("Failed to load LLM model. Check the model ID and internet connection.")

    # Load TTS models
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        # Load a speaker embedding (essential for SpeechT5 TTS)
        # Using a sample from a public dataset for demonstration
        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        # Using a specific speaker embedding (you can experiment with different indices)
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
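        # Index 7306 appears in the Hugging Face SpeechT5 examples; swapping in another row of
        # the dataset changes the voice, e.g. (illustrative alternative, not used here):
        # speaker_embeddings = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)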
        # Move TTS components to the same device as the LLM model
        device = llm_model.device if llm_model else 'cpu'
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        print("Please ensure TTS model IDs are correct and you have an internet connection.")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")

# --- Generate Response and Audio Function ---
@spaces.GPU  # Decorate with @spaces.GPU as this function performs GPU-intensive inference
def generate_response_and_audio(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history format (list of dictionaries with 'role' and 'content')
) -> tuple:  # Returns (updated_history, audio_file_path)
    """
    Generates a text response from the loaded LLM and then converts it to audio
    using the loaded TTS model.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    # Initialize all models if not already loaded
    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:  # Check LLM loading status
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # --- 1. Generate Text Response (LLM) ---
    # Format messages for the model's chat template.
    # Note: `messages` aliases `history`, so appending the user turn here also records it in the
    # returned history; only the assistant turn needs to be appended after generation.
    messages = history  # Already in the {'role', 'content'} format the chat template expects
    messages.append({"role": "user", "content": message})  # Add current user message

    # Apply the chat template and tokenize
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback for models without explicit chat templates.
        # `history` already contains the current user message (appended above),
        # so skip the last entry to avoid duplicating it in the prompt.
        input_text = ""
        for item in history[:-1]:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)

    # Generate response
    with torch.no_grad():  # Disable gradient calculations for inference
        output_ids = llm_model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id  # Avoids pad-token warnings; EOS still ends generation
        )

    # Decode the generated text, excluding the input prompt part
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
        try:
            # Ensure TTS components are on the correct device
            device = llm_model.device if llm_model else 'cpu'
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,  # Cap input length to prevent excessively long audio (see the chunking sketch after this function)
                truncation=True  # Truncate if the reply exceeds max_length
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            # Create a temporary file to save the audio
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
                # Ensure audio data is on CPU before saving with soundfile
                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
            print(f"Audio saved to: {audio_path}")
        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None  # Return None if audio generation fails
    else:
        print("TTS components not loaded. Skipping audio generation.")

    # --- 3. Update Chat History ---
    # Append the latest generated response to the history with its role
    history.append({"role": "assistant", "content": generated_text})

    return history, audio_path
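
# The TTS step above truncates replies at max_length=550 processor tokens. As a minimal sketch
# (an illustrative helper, not wired into the app; the name chunk_text_for_tts is an assumption),
# longer replies could instead be split into sentence-sized chunks, synthesized one by one,
# and the resulting waveforms concatenated before saving:
def chunk_text_for_tts(text: str, max_chars: int = 400) -> list:
    """Split text into roughly sentence-aligned chunks of at most ~max_chars characters."""
    chunks, current = [], ""
    for sentence in text.replace("!", ".").replace("?", ".").split("."):
        sentence = sentence.strip()
        if not sentence:
            continue
        if current and len(current) + len(sentence) + 2 > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks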

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chat bot
        Type your message below and chat with the model!
        """
    )
    # Set type='messages' for the chatbot to use OpenAI-style dictionaries
    chatbot = gr.Chatbot(label="Conversation", type='messages')
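    # With type='messages', the history is a list like:
    #   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
    # which is the same format generate_response_and_audio() appends to.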
    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4
        )
        submit_button = gr.Button("Send", scale=1)
    audio_output = gr.Audio(
        label="Listen to Response",
        autoplay=True,      # Automatically play audio
        interactive=False   # Don't allow user to interact with this audio component
    )

    # Link the text input and button to the generation function.
    # Outputs now include both the chatbot history and the audio file path.
    submit_button.click(
        fn=generate_response_and_audio,
        inputs=[text_input, chatbot],
        outputs=[chatbot, audio_output],
        queue=True  # Queue requests for better concurrency
    )
    text_input.submit(  # Also trigger on Enter key
        fn=generate_response_and_audio,
        inputs=[text_input, chatbot],
        outputs=[chatbot, audio_output],
        queue=True
    )

    # Clear button
    def clear_chat():
        # Clear history, text input, and audio output
        return [], "", None

    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input, audio_output])

# Load all models when the app starts up
load_models()

# Launch the Gradio app
demo.queue().launch()