# Hugging Face Spaces status banner: "Running on Zero"
import html
import os
import tempfile

import gradio as gr
import librosa
import soundfile as sf
import spaces
import torch
import yaml
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
# ================== Configuration ==================
# Chat LLM and the sampling settings passed to generate().
HUGGINGFACE_MODEL_ID: str = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
TORCH_DTYPE: torch.dtype = torch.bfloat16
MAX_NEW_TOKENS: int = 512
DO_SAMPLE: bool = True
TEMPERATURE: float = 0.7
TOP_K: int = 50
TOP_P: float = 0.95

# Speech models: SpeechT5 text-to-speech with its HiFi-GAN vocoder,
# and Whisper for speech-to-text.
TTS_MODEL_ID: str = "microsoft/speecht5_tts"
TTS_VOCODER_ID: str = "microsoft/speecht5_hifigan"
STT_MODEL_ID: str = "openai/whisper-small"
# ================== Global Variables ==================
# Model handles start empty; load_models() populates them on the first request.
# Chat LLM
tokenizer = llm_model = None
# Text-to-speech (SpeechT5 processor/model/vocoder and the speaker x-vector)
tts_processor = tts_model = tts_vocoder = speaker_embeddings = None
# Speech-to-text (Whisper)
whisper_processor = whisper_model = None
# Stays True until the first request has called load_models().
first_load = True
### UI Helpers | |
def generate_pretty_html(data):
    """Render a mapping as a styled "Model Info" HTML card.

    Args:
        data: Mapping of labels to values, such as the parsed ``config.yaml``.
            ``None`` (what an empty YAML file parses to) renders an empty card.

    Returns:
        The card as an HTML string. Keys and values are converted with
        ``str()`` and HTML-escaped, so characters like ``<`` or ``&`` are
        displayed literally instead of breaking the surrounding markup.
    """
    parts = [
        """
<div class="font-sans max-w-xl mx-auto bg-gray-800 text-white rounded-lg p-6 shadow-md">
<h2 class="text-xl font-semibold text-white border-b border-gray-600 pb-2 mb-4">Model Info</h2>
"""
    ]
    for key, value in (data or {}).items():
        parts.append(f"""
<div class="mb-3">
<strong class="text-blue-400 inline-block w-40">{html.escape(str(key))}:</strong>
<span class="text-gray-300">{html.escape(str(value))}</span>
</div>
""")
    parts.append("</div>")
    # Join once instead of repeated string concatenation.
    return "".join(parts)
def load_config():
    """Parse ``config.yaml`` from the working directory with ``yaml.safe_load``.

    Returns the parsed YAML document. Open and parse errors propagate to the caller.
    """
    with open("config.yaml", mode="r", encoding="utf-8") as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
def render_modern_info():
    """Build the "Model Info" panel HTML from ``config.yaml``.

    Any failure while reading, parsing, or rendering the config is returned
    as an inline red error message instead of being raised into Gradio.
    The error text is HTML-escaped because exception messages may echo
    file content containing markup characters.
    """
    try:
        return generate_pretty_html(load_config())
    except Exception as e:  # UI boundary: show the problem rather than fail the page load
        return f"<div style='color: red;'>Error loading config: {html.escape(str(e))}</div>"
def load_readme():
    """Return the complete UTF-8 text of ``README.md`` from the working directory."""
    with open("README.md", encoding="utf-8") as readme:
        contents = readme.read()
    return contents
# ================== Helper Functions ================== | |
def _wrap_words(sentence, limit):
    """Split *sentence* at whitespace into pieces shorter than *limit* characters.

    A single word longer than *limit* is kept whole rather than cut mid-word.
    """
    pieces = []
    current = ""
    for word in sentence.split():
        candidate = f"{current} {word}" if current else word
        if current and len(candidate) >= limit:
            pieces.append(current)
            current = word
        else:
            current = candidate
    if current:
        pieces.append(current)
    return pieces


def _terminate_chunk(chunk):
    """Right-strip *chunk* and make sure it ends in '.', '!' or '?'."""
    chunk = chunk.rstrip()
    return chunk if chunk.endswith((".", "!", "?")) else f"{chunk}."


def split_text_into_chunks(text, max_chars=400):
    """Split *text* into TTS-sized chunks made of whole sentences where possible.

    Sentences are found by splitting on ". " (after collapsing "..." to ".")
    and packed greedily into chunks shorter than *max_chars*. A sentence too
    long to fit on its own is wrapped at word boundaries rather than passed
    through oversized. Every returned chunk ends in '.', '!' or '?'. A period
    is appended only when the chunk does not already end with one of those,
    so text ending in "." no longer yields "..".

    Args:
        text: Text to split.
        max_chars: Chunk length limit. Returned chunks are at most this long,
            including any appended period. The only exception is a single
            word longer than the limit, which is never cut.

    Returns:
        List of non-blank chunks. Blank input yields an empty list.
    """
    sentences = text.replace("...", ".").split(". ")
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        # The +2 accounts for the ". " separator added when joining sentences.
        if len(current_chunk) + len(sentence) + 2 < max_chars:
            current_chunk = f"{current_chunk}. {sentence}" if current_chunk else sentence
            continue
        if current_chunk:
            chunks.append(current_chunk)
        if len(sentence) + 2 < max_chars:
            current_chunk = sentence
        else:
            # Oversized sentence: emit all but its last piece now. The last
            # piece ends the sentence and may still be joined with the next one.
            pieces = _wrap_words(sentence, max_chars - 2)
            chunks.extend(pieces[:-1])
            current_chunk = pieces[-1] if pieces else ""
    if current_chunk:
        chunks.append(current_chunk)
    return [_terminate_chunk(chunk) for chunk in chunks if chunk.strip()]
# ================== Model Loading ================== | |
def _llm_device():
    """Return the device the speech models should share: the LLM's, or CPU if it is not loaded."""
    return llm_model.device if llm_model is not None else "cpu"


def _load_llm(hf_token):
    """Load the chat tokenizer and causal LM into module globals if either is missing."""
    global tokenizer, llm_model
    if tokenizer is not None and llm_model is not None:
        return
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
        if tokenizer.pad_token is None:
            # tokenizer(..., padding=True) needs a pad token; reuse EOS when none is defined.
            tokenizer.pad_token = tokenizer.eos_token
        llm_model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",
            token=hf_token,
        ).eval()
        print("LLM loaded successfully.")
    except Exception as e:
        print(f"Error loading LLM: {e}")


def _load_tts(hf_token):
    """Load the SpeechT5 processor, model, vocoder and speaker x-vector if any are missing."""
    global tts_processor, tts_model, tts_vocoder, speaker_embeddings
    if tts_processor is not None and tts_model is not None and tts_vocoder is not None:
        return
    try:
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
        embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
        # Row 7306 is a fixed voice; unsqueeze(0) adds a leading dimension of size 1.
        speaker_embeddings = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)
        device = _llm_device()
        tts_model.to(device)
        tts_vocoder.to(device)
        speaker_embeddings = speaker_embeddings.to(device)
        print("TTS models loaded.")
    except Exception as e:
        print(f"Error loading TTS: {e}")


def _load_stt(hf_token):
    """Load the Whisper processor and model onto the LLM's device if either is missing."""
    global whisper_processor, whisper_model
    if whisper_processor is not None and whisper_model is not None:
        return
    try:
        whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)
        whisper_model.to(_llm_device())
        print("Whisper loaded.")
    except Exception as e:
        print(f"Error loading Whisper: {e}")


def load_models():
    """Load the LLM, TTS and STT models into module globals.

    Each group is loaded only while its globals are still ``None``. A failure
    in one group is printed and does not stop the others. The TTS and Whisper
    models are moved to the LLM's device (CPU if the LLM failed), so the LLM
    is loaded first. The Hub token is read from the ``HF_TOKEN`` environment
    variable.
    """
    hf_token = os.environ.get("HF_TOKEN")
    _load_llm(hf_token)
    _load_tts(hf_token)
    _load_stt(hf_token)
# ================== Chat & Audio Functions ================== | |
def _build_prompt(messages):
    """Render *messages* into an LLM prompt string.

    Uses the tokenizer's chat template with a generation prompt. If that
    raises, it falls back to a plain "Role: content" transcript ending in
    "Assistant:".
    """
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        transcript = "".join(f"{item['role'].capitalize()}: {item['content']}\n" for item in messages)
        return transcript + "Assistant:"


def _synthesize_speech(text):
    """Speak *text* with SpeechT5 and return the path of a 16 kHz WAV file.

    Returns ``None`` if *text* yields no chunks, or if synthesis or writing
    fails (the error is printed). A failed write no longer returns a path to
    an empty file.
    """
    try:
        device = llm_model.device
        speech_parts = []
        for chunk in split_text_into_chunks(text):
            tts_inputs = tts_processor(
                text=chunk, return_tensors="pt", max_length=512, truncation=True
            ).to(device)
            speech = tts_model.generate_speech(
                tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder
            )
            speech_parts.append(speech.cpu())
        if not speech_parts:
            return None
        waveform = torch.cat(speech_parts, dim=0)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            audio_path = tmp_file.name
        sf.write(audio_path, waveform.numpy(), samplerate=16000)
        return audio_path
    except Exception as e:
        print(f"TTS error: {e}")
        return None


def generate_response_and_audio(message, history):
    """Answer *message* with the LLM and, when TTS is available, speak the reply.

    Args:
        message: The user's new text message.
        history: Prior turns in Gradio "messages" format
            (dicts with ``"role"`` and ``"content"``).

    Returns:
        ``(chat_history, audio_path)``. *chat_history* is *history* plus the
        user's message and the assistant reply (or an error note), so the
        chatbot shows both sides and later turns keep the user context.
        *audio_path* is a WAV path, or ``None`` when TTS is unavailable or
        fails. A blank *message* returns *history* unchanged with no audio.
    """
    global first_load
    if first_load:
        # Models are loaded lazily on the first request rather than at import time.
        load_models()
        first_load = False

    history = list(history or [])
    if not message or not message.strip():
        return history, None

    messages = history + [{"role": "user", "content": message}]

    if tokenizer is None or llm_model is None:
        return messages + [{"role": "assistant", "content": "Error: LLM not loaded."}], None

    try:
        inputs = tokenizer(
            _build_prompt(messages), return_tensors="pt", padding=True, truncation=True
        ).to(llm_model.device)
        output_ids = llm_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        ).strip()
    except Exception as e:
        print(f"LLM error: {e}")
        return messages + [{"role": "assistant", "content": "I had an issue generating a response."}], None

    audio_path = None
    # Explicit identity checks: `None in [tensor, ...]` would invoke Tensor.__eq__.
    if all(part is not None for part in (tts_processor, tts_model, tts_vocoder, speaker_embeddings)):
        audio_path = _synthesize_speech(generated_text)

    return messages + [{"role": "assistant", "content": generated_text}], audio_path
def transcribe_audio(filepath):
    """Transcribe the audio file at *filepath* with Whisper and return the text.

    The first call (shared ``first_load`` flag) triggers ``load_models()``.
    Instead of raising, it returns a human-readable message when Whisper is
    not loaded or transcription fails.

    NOTE(review): the audio is resampled to 16 kHz and sent through the
    Whisper processor with default settings. That presumably keeps only the
    first ~30 s of long recordings; confirm this if long uploads matter.
    """
    global first_load
    if first_load:
        load_models()
        first_load = False

    if whisper_model is None:
        return "Whisper model not loaded."

    try:
        waveform, sampling_rate = librosa.load(filepath, sr=16000)
        features = whisper_processor(
            waveform, sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features
        token_ids = whisper_model.generate(features.to(whisper_model.device))
        transcript = whisper_processor.batch_decode(token_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return f"Transcription failed: {e}"
    return transcript
# ================== Gradio UI ==================
with gr.Blocks(head="""
<script src="https://cdn.tailwindcss.com"></script>
""") as demo:
    gr.Markdown("""
<div class="bg-gray-900 text-white p-4 rounded-lg shadow-md mb-6">
<h1 class="text-2xl font-bold">Qwen2.5 Chatbot with Voice Input/Output</h1>
<p class="text-gray-300">Powered by Gradio + TailwindCSS</p>
</div>
""")

    with gr.Tab("Chat"):
        gr.HTML("""
<div class="bg-gray-800 p-4 rounded-lg mb-4">
<label class="block text-gray-300 font-medium mb-2">Chat Interface</label>
</div>
""")
        chatbot = gr.Chatbot(type='messages', elem_classes=["bg-gray-800", "text-white"])
        text_input = gr.Textbox(
            placeholder="Type your message...",
            label="User Input",
            elem_classes=["bg-gray-700", "text-white", "border-gray-600"],
        )
        audio_output = gr.Audio(label="Response Audio", autoplay=True)
        # Clear the textbox after the reply so the next message starts empty.
        text_input.submit(
            generate_response_and_audio, [text_input, chatbot], [chatbot, audio_output]
        ).then(lambda: "", None, text_input)

    with gr.Tab("Transcribe"):
        gr.HTML("""
<div class="bg-gray-800 p-4 rounded-lg mb-4">
<label class="block text-gray-300 font-medium mb-2">Audio Transcription</label>
</div>
""")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        transcribed = gr.Textbox(
            label="Transcription",
            elem_classes=["bg-gray-700", "text-white", "border-gray-600"],
        )
        # gr.Audio also accepts microphone recordings, which fire stop_recording
        # rather than upload; wire both so recorded clips are transcribed too.
        audio_input.upload(transcribe_audio, audio_input, transcribed)
        audio_input.stop_recording(transcribe_audio, audio_input, transcribed)

    clear_btn = gr.Button("Clear All", elem_classes=["bg-gray-600", "hover:bg-gray-500", "text-white", "mt-4"])
    # "Clear All" resets both tabs: chat history, textbox, reply audio, input audio, transcript.
    clear_btn.click(
        lambda: ([], "", None, None, ""),
        None,
        [chatbot, text_input, audio_output, audio_input, transcribed],
    )

    html_output = gr.HTML("""
<div class="bg-gray-800 text-white p-4 rounded-lg mt-6 text-center">
Loading model info...
</div>
""")
    demo.load(fn=render_modern_info, outputs=html_output)

# ================== Launch App ==================
if __name__ == "__main__":
    demo.queue().launch()