File size: 3,677 Bytes
748ecaa d5d8bf3 748ecaa 46f1390 e5d26e9 748ecaa 46f1390 d5d8bf3 46f1390 d5d8bf3 46f1390 d5d8bf3 46f1390 d5d8bf3 46f1390 d5d8bf3 46f1390 b1f1246 46f1390 d5d8bf3 46f1390 d5d8bf3 46f1390 e5d26e9 46f1390 748ecaa d5d8bf3 e5d26e9 748ecaa d743fc1 46f1390 d743fc1 46f1390 d5d8bf3 46f1390 d5d8bf3 d7ca016 d5d8bf3 46f1390 d5d8bf3 46f1390 d5d8bf3 46f1390 d5d8bf3 46f1390 748ecaa 46f1390 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import torch
import torchaudio
import gradio as gr
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes
# Global cache to hold the loaded model
MODEL = None  # populated lazily by load_model(); None until the first call
# Device string used for inference tensors (speaker embedding / conditioning).
# NOTE(review): load_model() passes the literal "cuda" instead of this
# constant — keep the two in sync if you change the target device.
device = "cuda"
def load_model():
    """
    Load the Zonos model once and cache it in the module-level ``MODEL``.

    Adjust ``model_name`` if you want to switch from the hybrid to the
    transformer checkpoint, etc.

    Returns:
        The cached Zonos model in eval mode with gradients disabled.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Use the module-level `device` constant so the load target stays
        # consistent with the tensors prepared in tts() (was a hardcoded
        # "cuda" that could silently diverge from `device`).
        MODEL = Zonos.from_pretrained(model_name, device=device)
        MODEL = MODEL.requires_grad_(False).eval()
        MODEL.bfloat16()  # optional: only if your GPU supports bfloat16
        print("Model loaded successfully!")
    return MODEL
def tts(text, speaker_audio, selected_language):
    """
    Synthesize speech for `text` in the voice of `speaker_audio`.

    Args:
        text: str — the prompt to speak; empty/None aborts.
        speaker_audio: (sample_rate, numpy_array) tuple as produced by a
            Gradio Audio component with type="numpy"; None aborts.
        selected_language: str language code (e.g. "en-us", "es-es").

    Returns:
        (sample_rate, waveform_numpy) suitable for a Gradio audio output,
        or None when text or reference audio is missing.
    """
    model = load_model()
    # Guard clauses: nothing to synthesize / no voice to clone.
    if not text:
        return None
    if speaker_audio is None:
        return None
    # Gradio provides audio as (sample_rate, numpy_array).
    sr, wav_np = speaker_audio
    wav_tensor = torch.from_numpy(wav_np)
    if torch.is_floating_point(wav_tensor):
        wav_tensor = wav_tensor.float()
    else:
        # Gradio delivers integer PCM (typically int16); scale to [-1, 1]
        # so the embedding model sees a conventional float waveform.
        peak = float(max(abs(torch.iinfo(wav_tensor.dtype).min),
                         torch.iinfo(wav_tensor.dtype).max))
        wav_tensor = wav_tensor.float() / peak
    if wav_tensor.dim() == 1:
        # Mono: (num_samples,) -> (1, num_samples).
        wav_tensor = wav_tensor.unsqueeze(0)
    elif wav_tensor.dim() == 2:
        # Gradio uses (num_samples, num_channels); transpose so channels
        # come first, then downmix to mono, giving (1, num_samples).
        # (The original transposed-shape check ran after unsqueeze(0) and
        # could never fire, leaving stereo input as an invalid 3-D tensor.)
        if wav_tensor.shape[0] > wav_tensor.shape[1]:
            wav_tensor = wav_tensor.T
        wav_tensor = wav_tensor.mean(dim=0, keepdim=True)
    # Derive the speaker embedding from the reference audio.
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)
    # Prepare conditioning dictionary.
    cond_dict = make_cond_dict(
        text=text,                   # The text prompt
        speaker=spk_embedding,       # Speaker embedding
        language=selected_language,  # Language from the Dropdown
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)
    # Generate codes, then decode them into raw audio.
    with torch.no_grad():
        codes = model.generate(conditioning)
    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())
def build_demo():
    """
    Assemble and return the Gradio Blocks UI for the Zonos TTS demo.

    The layout is: a text prompt, a reference-audio upload for voice
    cloning, a language dropdown, a generate button, and an audio player
    for the synthesized result.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio + Language)")
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            speaker_clip = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )
        # Language codes come from the zonos package itself.
        lang_selector = gr.Dropdown(
            label="Language",
            choices=supported_language_codes,
            value="en-us",
            interactive=True,
        )
        synth_btn = gr.Button("Generate")
        # Playable widget for the synthesized waveform.
        result_player = gr.Audio(label="Synthesized Output", type="numpy")
        # Wire the button: text + reference clip + language -> tts -> audio.
        synth_btn.click(
            fn=tts,
            inputs=[prompt_box, speaker_clip, lang_selector],
            outputs=result_player,
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces with a public share link.
    app = build_demo()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)
|