#!/usr/bin/env python
# ruff: noqa: E402

import json
import os
import tempfile

import click
import gradio as gr
import soundfile as sf
import torchaudio

from cached_path import cached_path
from groq import Groq

# Try to import spaces; if available, set USING_SPACES to True so we can decorate functions for GPU support.
try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """
    Decorator that wraps a function with GPU acceleration if running in a Spaces environment.
    """
    if USING_SPACES:
        return spaces.GPU(func)
    return func


# Local package imports
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)

DEFAULT_TTS_MODEL = "F5-TTS"
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]
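# The three entries above are the checkpoint path, the vocab file, and the
# DiT backbone hyperparameters (embedding width, transformer depth, attention
# heads, feed-forward multiplier, text embedding width, number of conv layers)
# for the published F5-TTS Base model.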

# Load vocoder and TTS model
vocoder = load_vocoder()


def load_f5tts(
    ckpt_path: str = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
):
    """
    Load the F5-TTS model from the given checkpoint path.
    """
    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    return load_model(DiT, F5TTS_model_cfg, ckpt_path)


F5TTS_ema_model = load_f5tts()


client = None  # Groq client; created by load_groq once the user supplies an API key.


def load_groq(apikey):
    """
    Create the Groq client used by generate_response from the user-supplied
    API key. It is stored in a module-level variable so the Gradio callbacks
    can reach it.
    """
    global client
    client = Groq(api_key=apikey)


@gpu_decorator
def generate_response(messages):
    """
    Generate a chat response using the Groq API.
    If messages is a string, wrap it as a user message.
    """
    if isinstance(messages, str):
        messages_payload = [{"role": "user", "content": messages}]
    else:
        messages_payload = messages

    chat_completion = client.chat.completions.create(
        messages=messages_payload,
        model="deepseek-r1-distill-llama-70b",
        stream=False,
    )
    # Check that we got a valid response.
    if chat_completion.choices and hasattr(chat_completion.choices[0].message, "content"):
        return chat_completion.choices[0].message.content
    return ""


@gpu_decorator
def process_audio_input(audio_path, text, history, conv_state):
    """
    Process audio and/or text input from the user:
      - If an audio file is provided, it is transcribed and the transcript
        replaces the text input.
      - The conversation state and chat history are updated with the user
        message and the assistant's response.
    Returns the updated (history, conv_state) pair, matching the two output
    components wired up in the callbacks below.
    """
    if not audio_path and not text.strip():
        return history, conv_state

    if audio_path:
        # preprocess_ref_audio_text returns a tuple (audio, transcript).
        _, text = preprocess_ref_audio_text(audio_path, text)

    if not text.strip():
        return history, conv_state

    conv_state.append({"role": "user", "content": text})
    history.append((text, None))
    response = generate_response(conv_state)
    conv_state.append({"role": "assistant", "content": response})
    history[-1] = (text, response)
    return history, conv_state


@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    remove_silence,
    cross_fade_duration: float = 0.15,
    nfe_step: int = 32,
    speed: float = 1.0,
    show_info=print,
):
    """
    Generate speech audio using the F5-TTS system based on a reference audio/text and generated text.
    """
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text

    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text

    # Preprocess the reference audio and text.
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
    ema_model = F5TTS_ema_model  # Use the default F5-TTS model.

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    if remove_silence:
        # Write the generated waveform to a temporary file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            temp_audio_path = f.name
            sf.write(temp_audio_path, final_wave, final_sample_rate)
        # Process the file to remove silence.
        remove_silence_for_generated_wav(temp_audio_path)
        final_wave_tensor, _ = torchaudio.load(temp_audio_path)
        final_wave = final_wave_tensor.squeeze().cpu().numpy()
        os.unlink(temp_audio_path)  # Clean up the temporary file.

    # Save the spectrogram as a temporary PNG file.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
        save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text
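
# Usage sketch (illustrative only; "ref.wav" is a placeholder path, not a
# file shipped with this script):
#
#   (sr, wave), spec_path, ref_text = infer(
#       "ref.wav", "", "Hello from F5-TTS!", remove_silence=False
#   )
#   sf.write("output.wav", wave, sr)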


with gr.Blocks() as app_chat:
    gr.Markdown(
        """
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and, optionally, its transcript.
2. Enter your Groq API key and click "Load API".
3. Record your message through the microphone or type it.
4. The AI will respond using the reference voice.
"""
    )

    # The chat interface container is visible only if running in a Spaces environment.
    chat_interface_container = gr.Column(visible=USING_SPACES)

    with chat_interface_container:
        with gr.Row():
            with gr.Column():
                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=False):
                    remove_silence_chat = gr.Checkbox(label="Remove Silences", value=True)
                    ref_text_chat = gr.Textbox(
                        label="Reference Text",
                        info="Optional: Leave blank to auto-transcribe",
                        lines=2,
                    )
                    system_prompt_chat = gr.Textbox(
                        label="System Prompt",
                        value=(
                            "You are not an AI assistant, you are whoever the user says you are. "
                            "You must stay in character. Keep your responses concise since they will be spoken out loud."
                        ),
                        lines=2,
                    )

        chatbot_interface = gr.Chatbot(label="Conversation")
        with gr.Row():
            with gr.Column():
                audio_input_chat = gr.Microphone(label="Speak your message", type="filepath")
                audio_output_chat = gr.Audio(autoplay=True)
            with gr.Column():
                with gr.Row():
                    groq_apikey = gr.Textbox(label="Your Groq API Key")
                    load_api = gr.Button("Load API")
                text_input_chat = gr.Textbox(label="Type your message", lines=1)
                send_btn_chat = gr.Button("Send Message")
                clear_btn_chat = gr.Button("Clear Conversation")

        # Initialize the conversation state with the system prompt.
        conversation_state = gr.State(
            value=[
                {
                    "role": "system",
                    "content": (
                        "You are not an AI assistant, you are whoever the user says you are. "
                        "You must stay in character. Keep your responses concise since they will be spoken out loud."
                    ),
                }
            ]
        )
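
        # gr.State holds this message list per browser session, so each
        # visitor gets an independent conversation that always starts from
        # the system prompt above.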

        @gpu_decorator
        def generate_audio_response(history, ref_audio, ref_text, remove_silence):
            """
            Generate an audio response from the last AI message in the conversation.
            """
            if not history or not ref_audio:
                return None, ref_text

            _, last_ai_response = history[-1]
            if not last_ai_response:
                return None, ref_text

            audio_result, _, ref_text_out = infer(
                ref_audio,
                ref_text,
                last_ai_response,
                remove_silence,
                cross_fade_duration=0.15,
                speed=1.0,
                show_info=print,
            )
            return audio_result, ref_text_out

        def clear_conversation():
            """
            Clear the chat conversation and reset the conversation state.
            """
            initial_state = [
                {
                    "role": "system",
                    "content": (
                        "You are not an AI assistant, you are whoever the user says you are. "
                        "You must stay in character. Keep your responses concise since they will be spoken out loud."
                    ),
                }
            ]
            return [], initial_state

        def update_system_prompt(new_prompt):
            """
            Update the system prompt and reset the conversation.
            """
            initial_state = [{"role": "system", "content": new_prompt}]
            return [], initial_state

        # Callbacks: when recording stops, text is submitted, or "Send Message"
        # is clicked, process_audio_input updates the chat, then
        # generate_audio_response speaks the reply in the reference voice,
        # then the input component is cleared.
        audio_input_chat.stop_recording(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
            outputs=[audio_output_chat, ref_text_chat],
        ).then(lambda: None, None, audio_input_chat)

        load_api.click(load_groq, inputs=groq_apikey, outputs=None)
        text_input_chat.submit(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
            outputs=[audio_output_chat, ref_text_chat],
        ).then(lambda: None, None, text_input_chat)

        send_btn_chat.click(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
            outputs=[audio_output_chat, ref_text_chat],
        ).then(lambda: None, None, text_input_chat)

        clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])
        system_prompt_chat.change(
            update_system_prompt,
            inputs=system_prompt_chat,
            outputs=[chatbot_interface, conversation_state],
        )


app = app_chat


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application")
def main(port, host, share, api, root_path):
    """
    Launch the Gradio app.
    """
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
        root_path=root_path,
    )


if __name__ == "__main__":
    main()