import gradio as gr
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan,
    WhisperProcessor, WhisperForConditionalGeneration
)
from datasets import load_dataset
import os
import spaces
import tempfile
import soundfile as sf
import librosa
import yaml
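# Note: soundfile and librosa handle audio file writing and loading/resampling;
# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU Spaces.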
# ================== Configuration ==================
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95
TTS_MODEL_ID = "microsoft/speecht5_tts"
TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
STT_MODEL_ID = "openai/whisper-small"
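# Note: TEMPERATURE, TOP_K and TOP_P only take effect because DO_SAMPLE is True;
# with do_sample=False, transformers' generate() decodes greedily and ignores
# these sampling parameters.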
# ================== Global Variables ==================
tokenizer = None
llm_model = None
tts_processor = None
tts_model = None
tts_vocoder = None
speaker_embeddings = None
whisper_processor = None
whisper_model = None
first_load = True
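# Models are loaded lazily on the first request (guarded by `first_load`)
# rather than at import time, so the Space starts quickly and heavy weights
# are only downloaded once a user actually interacts with the app.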
# ================== UI Helpers ==================
def generate_pretty_html(data):
    html = """
    <div style="font-family: Arial, sans-serif; max-width: 600px; margin: auto;
                background-color: #f9f9f9; border-radius: 10px; padding: 20px;
                box-shadow: 0 4px 12px rgba(0,0,0,0.1);">
        <h2 style="color: #2c3e50; border-bottom: 2px solid #ddd; padding-bottom: 10px;">Model Info</h2>
    """
    for key, value in data.items():
        html += f"""
        <div style="margin-bottom: 12px;">
            <strong style="color: #34495e; display: inline-block; width: 160px;">{key}:</strong>
            <span style="color: #2c3e50;">{value}</span>
        </div>
        """
    html += "</div>"
    return html
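# Assumes `data` is a flat mapping (e.g. the parsed config.yaml); values are
# interpolated via their string form, so nested structures render as-is.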
def load_config():
    with open("config.yaml", "r", encoding="utf-8") as f:
        return yaml.safe_load(f)  # Loads only the first YAML document

def render_modern_info():
    try:
        config = load_config()
        return generate_pretty_html(config)
    except Exception as e:
        return f"<div style='color: red;'>Error loading config: {str(e)}</div>"

def load_readme():
    with open("README.md", "r", encoding="utf-8") as f:
        return f.read()
# ================== Helper Functions ==================
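# SpeechT5 truncates long inputs (see max_length=512 in the TTS call below), so
# long LLM responses are split into sentence-sized chunks, synthesized one at a
# time, and the resulting waveforms concatenated in generate_response_and_audio().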
def split_text_into_chunks(text, max_chars=400):
    sentences = text.replace("...", ".").split(". ")
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) + 2 < max_chars:
            current_chunk += (". " + sentence) if current_chunk else sentence
        else:
            chunks.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)
    return [f"{chunk}." for chunk in chunks if chunk.strip()]
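# Worked example (illustrative):
#   split_text_into_chunks("One. Two. Three.", max_chars=12)
#   -> ["One. Two.", "Three.."]
# Note the trailing "." is re-appended to every chunk, so a chunk that already
# ends in "." (the last sentence here) comes back with "..".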
# ================== Model Loading ==================
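# On ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the duration of the
# decorated call; loading inside such a call lets device_map="auto" see the
# allocated device when placing the LLM.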
@spaces.GPU
def load_models():
    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings, whisper_processor, whisper_model
    hf_token = os.environ.get("HF_TOKEN")

    # LLM
    if tokenizer is None or llm_model is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            llm_model = AutoModelForCausalLM.from_pretrained(
                HUGGINGFACE_MODEL_ID,
                torch_dtype=TORCH_DTYPE,
                device_map="auto",
                token=hf_token
            ).eval()
            print("LLM loaded successfully.")
        except Exception as e:
            print(f"Error loading LLM: {e}")

    # TTS
    if tts_processor is None or tts_model is None or tts_vocoder is None:
        try:
            tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
            tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
            tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
            embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
            speaker_embeddings = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)
            device = llm_model.device if llm_model else 'cpu'
            tts_model.to(device)
            tts_vocoder.to(device)
            speaker_embeddings = speaker_embeddings.to(device)
            print("TTS models loaded.")
        except Exception as e:
            print(f"Error loading TTS: {e}")

    # STT
    if whisper_processor is None or whisper_model is None:
        try:
            whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
            whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)
            device = llm_model.device if llm_model else 'cpu'
            whisper_model.to(device)
            print("Whisper loaded.")
        except Exception as e:
            print(f"Error loading Whisper: {e}")
# ================== Chat & Audio Functions ==================
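# `history` arrives as a list of {"role": ..., "content": ...} dicts because
# the Chatbot below is created with type='messages'; that same structure can
# be fed directly into tokenizer.apply_chat_template().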
@spaces.GPU
def generate_response_and_audio(message, history):
    global first_load
    if first_load:
        load_models()
        first_load = False

    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings

    # Build the full conversation up front so every return path (including the
    # error paths) keeps the user's turn visible in the Chatbot.
    messages = history.copy()
    messages.append({"role": "user", "content": message})

    if tokenizer is None or llm_model is None:
        return messages + [{"role": "assistant", "content": "Error: LLM not loaded."}], None

    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fallback for tokenizers without a chat template: build a plain
        # "Role: content" transcript by hand.
        input_text = ""
        for item in history:
            input_text += f"{item['role'].capitalize()}: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    try:
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
        output_ids = llm_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"LLM error: {e}")
        return messages + [{"role": "assistant", "content": "I had an issue generating a response."}], None

    audio_path = None
    if None not in [tts_processor, tts_model, tts_vocoder, speaker_embeddings]:
        try:
            device = llm_model.device
            text_chunks = split_text_into_chunks(generated_text)
            full_speech = []
            for chunk in text_chunks:
                tts_inputs = tts_processor(text=chunk, return_tensors="pt", max_length=512, truncation=True).to(device)
                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
                full_speech.append(speech.cpu())
            full_speech_tensor = torch.cat(full_speech, dim=0)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                audio_path = tmp_file.name
            # SpeechT5 synthesizes at 16 kHz.
            sf.write(audio_path, full_speech_tensor.numpy(), samplerate=16000)
        except Exception as e:
            print(f"TTS error: {e}")

    return messages + [{"role": "assistant", "content": generated_text}], audio_path
@spaces.GPU
def transcribe_audio(filepath):
    global first_load
    if first_load:
        load_models()
        first_load = False

    global whisper_processor, whisper_model
    if whisper_model is None:
        return "Whisper model not loaded."
    try:
        audio, sr = librosa.load(filepath, sr=16000)
        inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features.to(whisper_model.device)
        outputs = whisper_model.generate(inputs)
        return whisper_processor.batch_decode(outputs, skip_special_tokens=True)[0]
    except Exception as e:
        return f"Transcription failed: {e}"
# ================== Gradio UI ==================
with gr.Blocks() as demo:
    gr.Markdown("# Qwen2.5 Chatbot with Voice Input/Output")

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(type='messages')
        text_input = gr.Textbox(placeholder="Type your message...")
        audio_output = gr.Audio(label="Response Audio", autoplay=True)
        text_input.submit(generate_response_and_audio, [text_input, chatbot], [chatbot, audio_output])

    with gr.Tab("Transcribe"):
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        transcribed = gr.Textbox(label="Transcription")
        audio_input.upload(transcribe_audio, audio_input, transcribed)

    clear_btn = gr.Button("Clear All")
    clear_btn.click(lambda: ([], "", None), None, [chatbot, text_input, audio_output])

    gr.Markdown(load_readme())
    gr.Markdown("---")

    # ✅ Define html_output BEFORE using it
    html_output = gr.HTML("<div style='text-align:center; padding: 20px;'>Loading model info...</div>")
    # ✅ Now this works!
    demo.load(fn=render_modern_info, outputs=html_output)
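# demo.load() fires once per page load, so the model-info panel is re-rendered
# from config.yaml every time a user opens or refreshes the app.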
# ================== Launch App ==================
demo.queue().launch()