#!/usr/bin/env python
# ruff: noqa: E402
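"""
Voice-chat demo built with Gradio: the user's spoken (or typed) message is
transcribed, a reply is generated by a Groq-hosted LLM, and the reply is
spoken back in the user's reference voice via F5-TTS.
"""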

import json
import os
import tempfile

import click
import gradio as gr
import soundfile as sf
import torchaudio

from cached_path import cached_path
from groq import Groq

# Try to import spaces; if available, wrap functions for GPU support.
try:
    import spaces
    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """
    Decorator that wraps a function with GPU acceleration if running in a Spaces environment.
    """
    if USING_SPACES:
        return spaces.GPU(func)
    return func


# Local package imports
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)

DEFAULT_TTS_MODEL = "F5-TTS"
# [checkpoint path, vocab path, model architecture config serialized as JSON]
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]

# Shared system prompt: used as the UI default, the initial conversation state,
# and the reset state when the conversation is cleared.
DEFAULT_SYSTEM_PROMPT = (
    "You are not an AI assistant, you are whoever the user says you are. "
    "You must stay in character. Keep your responses concise since they will be spoken out loud."
)

# Load the vocoder once at import time (the TTS model is loaded below).
vocoder = load_vocoder()


def load_f5tts(
    ckpt_path: str = str(cached_path(DEFAULT_TTS_MODEL_CFG[0])),
):
    """
    Load the F5-TTS model from the given checkpoint path.
    """
    model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    return load_model(DiT, model_cfg, ckpt_path)


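# Instantiate the default F5-TTS model once at import time so every request reuses it.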
F5TTS_ema_model = load_f5tts()


@gpu_decorator
def generate_response(messages, apikey):
    """
    Generate a chat response using the Groq API.
    If messages is a string, wrap it as a user message.
    """
    if isinstance(messages, str):
        messages_payload = [{"role": "user", "content": messages}]
    else:
        messages_payload = messages

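    # Send the full message history to Groq's chat completions API (non-streaming).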
    client = Groq(api_key=apikey)
    chat_completion = client.chat.completions.create(
        messages=messages_payload,
        model="deepseek-r1-distill-llama-70b",
        stream=False,
    )
    if chat_completion.choices and hasattr(chat_completion.choices[0].message, "content"):
        return chat_completion.choices[0].message.content
    return ""


@gpu_decorator
def process_audio_input(audio_path, text, apikey, history, conv_state):
    """
    Process audio and/or text input from the user:
      - If an audio file is provided, its transcript is obtained.
      - The conversation state and history are updated.
    
    Updated to construct the chat history as a list of dictionaries.
    """
    if not audio_path and not text.strip():
        return history, conv_state, ""

    if audio_path:
        # preprocess_ref_audio_text returns a tuple (audio, transcript)
        _, text = preprocess_ref_audio_text(audio_path, text)

    if not text.strip():
        return history, conv_state, ""

    # Wrap the user input in a dict.
    user_msg = {"role": "user", "content": text}
    conv_state.append(user_msg)
    history.append(user_msg)

    response = generate_response(conv_state, apikey)
    assistant_msg = {"role": "assistant", "content": response}
    conv_state.append(assistant_msg)
    history.append(assistant_msg)
    return history, conv_state, ""


@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    remove_silence,
    cross_fade_duration: float = 0.15,
    nfe_step: int = 32,
    speed: float = 1,
    show_info=print,
):
    """
    Generate speech audio using the F5-TTS system based on a reference audio/text and generated text.
    """
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text

    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text

    # Preprocess the reference audio and text.
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
    ema_model = F5TTS_ema_model  # Use the default F5-TTS model.

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    if remove_silence:
        # Write the generated waveform to a temporary file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            temp_audio_path = f.name
            sf.write(temp_audio_path, final_wave, final_sample_rate)
        # Process the file to remove silence.
        remove_silence_for_generated_wav(temp_audio_path)
        final_wave_tensor, _ = torchaudio.load(temp_audio_path)
        final_wave = final_wave_tensor.squeeze().cpu().numpy()
        os.unlink(temp_audio_path)  # Clean up the temporary file.

    # Save the spectrogram as a temporary PNG file.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
        save_spectrogram(combined_spectrogram, spectrogram_path)

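    # gr.Audio accepts audio as a (sample_rate, numpy array) tuple.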
    return (final_sample_rate, final_wave), spectrogram_path, ref_text


with gr.Blocks() as app:
    gr.Markdown(
        """
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Enter your Groq API key.
3. Record your message through your microphone or type it.
4. The AI will respond using the reference voice.
"""
    )

    with gr.Row():
        with gr.Column():
            ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
        with gr.Column():
            with gr.Accordion("Advanced Settings", open=False):
                remove_silence_chat = gr.Checkbox(label="Remove Silences", value=True)
                ref_text_chat = gr.Textbox(
                    label="Reference Text",
                    info="Optional: Leave blank to auto-transcribe",
                    lines=2,
                )
                system_prompt_chat = gr.Textbox(
                    label="System Prompt",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=2,
                )

    # Specify the chatbot type to avoid deprecation warnings.
    chatbot_interface = gr.Chatbot(label="Conversation", type="messages")

    with gr.Row():
        with gr.Column():
            audio_input_chat = gr.Microphone(label="Speak your message", type="filepath")
            audio_output_chat = gr.Audio(autoplay=True)
        with gr.Column():
            groq_apikey = gr.Textbox(label="Your Groq API Key")
            text_input_chat = gr.Textbox(label="Type your message", lines=1)
            send_btn_chat = gr.Button("Send Message")
            clear_btn_chat = gr.Button("Clear Conversation")

    # Initialize the conversation state with the system prompt.
    conversation_state = gr.State(
        value=[{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
    )

    # Create a dummy hidden output to capture the extra (unused) output.
    dummy_output = gr.Textbox(visible=False)

    @gpu_decorator
    def generate_audio_response(history, ref_audio, ref_text, remove_silence):
        """
        Generate an audio response from the last AI message in the conversation.
        Returns the generated audio, the (possibly updated) reference text, and the unchanged chat history.
        """
        if not history or not ref_audio:
            return None, ref_text, history

        # Find the last assistant message in the history.
        last_assistant = None
        for message in reversed(history):
            if message.get("role") == "assistant":
                last_assistant = message
                break
        if last_assistant is None or not last_assistant.get("content", "").strip():
            return None, ref_text, history

        audio_result, _, ref_text_out = infer(
            ref_audio,
            ref_text,
            last_assistant["content"],
            remove_silence,
            cross_fade_duration=0.15,
            speed=1.0,
            show_info=print,
        )
        return audio_result, ref_text_out, history

    def clear_conversation():
        """
        Clear the chat conversation and reset the conversation state.
        """
        initial_state = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
        return [], initial_state

    def update_system_prompt(new_prompt):
        """
        Update the system prompt and reset the conversation.
        """
        initial_state = [{"role": "system", "content": new_prompt}]
        return [], initial_state

    # Set up callbacks so that when recording stops or text is submitted, the processing chain is run.
    audio_input_chat.stop_recording(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, groq_apikey, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state, dummy_output],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat, chatbot_interface],
    ).then(lambda: None, None, audio_input_chat)

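    # The same chain (update chat -> synthesize the reply -> clear the input)
    # is wired to text submission and the Send button below.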
    text_input_chat.submit(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, groq_apikey, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state, dummy_output],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat, chatbot_interface],
    ).then(lambda: None, None, text_input_chat)

    send_btn_chat.click(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, groq_apikey, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state, dummy_output],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat, chatbot_interface],
    ).then(lambda: None, None, text_input_chat)

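    # Clearing resets both the visible chat and the LLM context; editing the
    # system prompt also resets the conversation, seeded with the new prompt.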
    clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])
    system_prompt_chat.change(
        update_system_prompt,
        inputs=system_prompt_chat,
        outputs=[chatbot_interface, conversation_state],
    )


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application")
def main(port, host, share, api, root_path):
    """
    Launch the Gradio app.
    """
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
        root_path=root_path,
    )


if __name__ == "__main__":
    main()
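# Example invocation (assuming this script is saved as, e.g., app.py):
#   python app.py --port 7860 --host 0.0.0.0
#   python app.py --share   # serve via a public Gradio link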