import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    WhisperProcessor,                 # New: for speech-to-text
    WhisperForConditionalGeneration,  # New: for speech-to-text
)
from datasets import load_dataset  # Provides a speaker embedding for TTS
import os
import spaces    # Hugging Face Spaces library (ZeroGPU @spaces.GPU decorator)
import tempfile  # For creating temporary audio files
import soundfile as sf  # To save audio files
import librosa   # New: for loading audio files for transcription

# --- Configuration for Language Model (LLM) ---
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Configuration for Text-to-Speech (TTS) ---
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"

# --- Configuration for Speech-to-Text (STT) ---
STT_MODEL_ID = "openai/whisper-tiny"  # Using a smaller Whisper model for faster inference
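# Note: whisper-tiny trades accuracy for speed; a larger checkpoint such as
# "openai/whisper-small" can be swapped in via STT_MODEL_ID if transcription
# quality matters more than latency.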

# --- Global variables for models and tokenizers/processors ---
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None  # New: global for Whisper processor
whisper_model = None      # New: global for Whisper model
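# Keeping the models in module-level globals means they are loaded once per
# worker process and then reused across Gradio requests.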


# --- Load All Models Function ---
# Called once at startup. The GPU-heavy inference functions further below are
# the ones decorated with @spaces.GPU, so GPU access is requested only while they run.
def load_models():
    """
    Loads the language model, tokenizer, TTS models, speaker embeddings,
    and STT (Whisper) models from the Hugging Face Hub.
    This function is called once when the Gradio app starts up.
    """
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
    global whisper_processor, whisper_model

    if (tokenizer is not None and llm_model is not None and tts_model is not None and
            whisper_processor is not None and whisper_model is not None):
        print("All models and tokenizers/processors already loaded.")
        return
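
    # HF_TOKEN is read from the environment (e.g. a Space secret) and is only
    # required for gated or private repositories; public models load without it.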
    hf_token = os.environ.get("HF_TOKEN")

    # Load Language Model (LLM)
    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",
            token=hf_token,
        )
        llm_model.eval()
        print("LLM model loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM model or tokenizer: {e}")
        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")

    # Load TTS models
    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)

        print("Loading speaker embeddings for TTS...")
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
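        # Row 7306 of the CMU Arctic x-vector dataset is the speaker embedding
        # used in the Hugging Face SpeechT5 examples; a different row gives a different voice.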
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        device = llm_model.device if llm_model is not None else "cpu"
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading TTS models or speaker embeddings: {e}")
        tts_processor = None
        tts_model = None
        tts_vocoder = None
        speaker_embeddings = None
        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")

    # Load STT (Whisper) model
    print(f"Loading STT (Whisper) processor and model from: {STT_MODEL_ID}")
    try:
        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)

        device = llm_model.device if llm_model is not None else "cpu"  # Use the same device as the LLM
        whisper_model.to(device)
        print(f"STT (Whisper) model loaded successfully to device: {device}.")
    except Exception as e:
        print(f"Error loading STT (Whisper) model or processor: {e}")
        whisper_processor = None
        whisper_model = None
        raise RuntimeError("Failed to load STT (Whisper) components. Check model ID and internet connection.")


# --- Generate Response and Audio Function ---
# Decorated with @spaces.GPU because this function performs GPU-intensive inference.
@spaces.GPU
def generate_response_and_audio(
    message: str,   # Current user message
    history: list,  # Gradio Chatbot history (list of dicts with 'role' and 'content')
) -> tuple:         # Returns (updated_history, audio_file_path)
""" | |
Generates a text response from the loaded LLM and then converts it to audio | |
using the loaded TTS model. | |
""" | |
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    # Initialize all models if not already loaded
    if tokenizer is None or llm_model is None or tts_model is None:
        load_models()

    if tokenizer is None or llm_model is None:  # Check LLM loading status
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None

    # --- 1. Generate Text Response (LLM) ---
    # Append the new user message to the shared history (already in chat-template format)
    history.append({"role": "user", "content": message})
    messages = history

    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback: build a plain "User:/Assistant:" transcript if the tokenizer has no chat template
        input_text = ""
        for item in messages:  # messages already includes the new user message
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += "Assistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)

    with torch.no_grad():
        output_ids = llm_model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (everything after the prompt)
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- 2. Generate Audio from Response (TTS) ---
    audio_path = None
    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
        try:
            device = llm_model.device if llm_model is not None else "cpu"
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)

            tts_inputs = tts_processor(
                text=generated_text,
                return_tensors="pt",
                max_length=550,
                truncation=True,
            ).to(device)

            with torch.no_grad():
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)

            # Write the waveform to a temporary .wav file (SpeechT5 produces 16 kHz audio)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
            print(f"Audio saved to: {audio_path}")
        except Exception as e:
            print(f"Error generating audio: {e}")
            audio_path = None
    else:
        print("TTS components not loaded. Skipping audio generation.")

    # --- 3. Update Chat History ---
    history.append({"role": "assistant", "content": generated_text})
    return history, audio_path


# --- Transcribe Audio Function (NEW) ---
# Decorated with @spaces.GPU because Whisper inference also runs on the GPU.
@spaces.GPU
def transcribe_audio(audio_filepath):
""" | |
Transcribes an audio file using the loaded Whisper model. | |
Handles audio files of varying lengths. | |
""" | |
    global whisper_processor, whisper_model

    if whisper_processor is None or whisper_model is None:
        load_models()  # Attempt to load if not already loaded
        if whisper_processor is None or whisper_model is None:
            return "Error: Speech-to-Text model not loaded. Please check logs."

    if audio_filepath is None:
        return "No audio input provided for transcription."

    print(f"Transcribing audio from: {audio_filepath}")
    try:
        # Load the audio file and resample to 16 kHz (Whisper's required sample rate)
        audio, sample_rate = librosa.load(audio_filepath, sr=16000)

        # Process the audio input for the Whisper model
        input_features = whisper_processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features.to(whisper_model.device)

        # Generate transcription token IDs
        predicted_ids = whisper_model.generate(input_features)

        # Decode the IDs to text
        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription: {transcription}")
        return transcription
    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Transcription failed: {e}"


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd Chatbot with Voice Input & Output
        Type your message or speak into the microphone to chat with the model.
        The chatbot's response is spoken aloud, and your audio input can be transcribed.
        """
    )

    with gr.Tab("Chat with Voice"):
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        with gr.Row():
            text_input = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                scale=4,
            )
            submit_button = gr.Button("Send", scale=1)
        audio_output = gr.Audio(
            label="Listen to Response",
            autoplay=True,
            interactive=False,
        )

        submit_button.click(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True,
        )
        text_input.submit(
            fn=generate_response_and_audio,
            inputs=[text_input, chatbot],
            outputs=[chatbot, audio_output],
            queue=True,
        )

    with gr.Tab("Audio Transcription"):
        stt_audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio or Record from Microphone",
            # 'microphone=True' and 'source' are omitted because they raise a TypeError on older Gradio versions
            format="wav",  # Ensure a consistent format
        )
        transcribe_button = gr.Button("Transcribe Audio")
        transcribed_text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcription will appear here...",
            interactive=False,
        )

        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[stt_audio_input],
            outputs=[transcribed_text_output],
            queue=True,
        )

    # Clear button for the entire interface
    def clear_all():
        # Clears: chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output
        return [], "", None, None, ""

    clear_button = gr.Button("Clear All")
    clear_button.click(
        clear_all,
        inputs=None,
        outputs=[chatbot, text_input, audio_output, stt_audio_input, transcribed_text_output],
    )


# Load all models when the app starts up
load_models()

# Launch the Gradio app
demo.queue().launch()