File size: 3,677 Bytes
748ecaa
 
 
 
 
d5d8bf3
748ecaa
46f1390
 
e5d26e9
748ecaa
46f1390
 
 
d5d8bf3
46f1390
 
 
 
 
 
 
d5d8bf3
46f1390
 
 
d5d8bf3
46f1390
 
 
d5d8bf3
 
46f1390
 
 
 
d5d8bf3
46f1390
b1f1246
46f1390
d5d8bf3
46f1390
 
 
d5d8bf3
46f1390
 
 
 
 
 
 
 
 
 
 
e5d26e9
46f1390
 
748ecaa
d5d8bf3
 
 
e5d26e9
748ecaa
d743fc1
46f1390
 
d743fc1
 
46f1390
 
 
 
 
 
 
 
 
d5d8bf3
46f1390
 
 
 
 
 
 
 
 
 
 
d5d8bf3
 
 
d7ca016
d5d8bf3
 
 
46f1390
 
 
d5d8bf3
46f1390
 
d5d8bf3
46f1390
 
d5d8bf3
46f1390
 
 
 
748ecaa
 
46f1390
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torchaudio
import gradio as gr

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes

# Global cache for the loaded model; populated lazily by load_model() on first use.
MODEL = None
# Inference device — the code assumes a CUDA-capable GPU is present (see load_model).
device = "cuda"

def load_model():
    """
    Load the Zonos model once and cache it in the module-level MODEL global.

    Returns:
        The cached Zonos model, in eval mode with gradients disabled.

    Adjust ``model_name`` if you want to switch from hybrid to transformer, etc.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Use the module-level `device` constant instead of a hard-coded "cuda"
        # so the load target stays consistent with tts(), which moves tensors
        # to `device`.
        MODEL = Zonos.from_pretrained(model_name, device=device)
        MODEL = MODEL.requires_grad_(False).eval()
        MODEL.bfloat16()  # optional if your GPU supports bfloat16
        print("Model loaded successfully!")
    return MODEL

def tts(text, speaker_audio, selected_language):
    """
    text: str
    speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
    selected_language: str (e.g., "en-us", "es-es", etc.)
    
    Returns (sample_rate, waveform) for Gradio audio output.
    """
    model = load_model()

    # If no text, return None
    if not text:
        return None

    # If no reference audio, return None
    if speaker_audio is None:
        return None

    # Gradio provides audio in (sample_rate, numpy_array)
    sr, wav_np = speaker_audio

    # Convert to Torch tensor: shape (1, num_samples)
    wav_tensor = torch.from_numpy(wav_np).unsqueeze(0).float()
    if wav_tensor.dim() == 2 and wav_tensor.shape[0] > wav_tensor.shape[1]:
        # If shape is transposed, fix it
        wav_tensor = wav_tensor.T

    # Get speaker embedding
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Prepare conditioning dictionary
    cond_dict = make_cond_dict(
        text=text,                   # The text prompt
        speaker=spk_embedding,       # Speaker embedding
        language=selected_language,  # Language from the Dropdown
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes
    with torch.no_grad():
        codes = model.generate(conditioning)

    # Decode the codes into raw audio
    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate

    return (sr_out, wav_out.numpy())

def build_demo():
    """Assemble and return the Gradio Blocks UI for the Zonos TTS demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio + Language)")

        with gr.Row():
            prompt_box = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            speaker_clip = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )

        # Language selector feeding the conditioning dictionary.
        lang_select = gr.Dropdown(
            label="Language",
            choices=supported_language_codes,
            value="en-us",
            interactive=True,
        )

        run_btn = gr.Button("Generate")

        # Playback widget for the synthesized result.
        result_audio = gr.Audio(label="Synthesized Output", type="numpy")

        # Wire the button to tts(): text, reference clip, and language in;
        # synthesized audio out.
        run_btn.click(
            fn=tts,
            inputs=[prompt_box, speaker_clip, lang_select],
            outputs=result_audio,
        )

    return demo

if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and create a public share link.
    build_demo().launch(server_name="0.0.0.0", server_port=7860, share=True)