# ruff: noqa: E402
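"""Voice-chat Gradio app: F5-TTS speaks LLM replies in a cloned reference voice.

Speech synthesis and reference-audio transcription run locally via f5_tts;
the chat replies themselves are generated remotely through the Groq API.
"""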
import json
import os
import tempfile

import click
import gradio as gr
import soundfile as sf
import torchaudio
from cached_path import cached_path
from groq import Groq
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """Wrap `func` with spaces.GPU when running on HF Spaces (ZeroGPU); no-op locally."""
    if USING_SPACES:
        return spaces.GPU(func)
    return func


from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)


DEFAULT_TTS_MODEL = "F5-TTS"
tts_model_choice = DEFAULT_TTS_MODEL

# (checkpoint path, vocab path, model config) for the default F5-TTS model;
# load_f5tts() below uses the same checkpoint and config.
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]


# Load models
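# load_vocoder() returns the mel-spectrogram-to-waveform vocoder (Vocos by default in f5_tts).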
vocoder = load_vocoder()

def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    return load_model(DiT, F5TTS_model_cfg, ckpt_path)


F5TTS_ema_model = load_f5tts()

# Local chat model handles, filled in lazily by load_chat_model() below.
# Note that chat responses themselves come from the Groq API (see generate_response).
chat_model_state = None
chat_tokenizer_state = None

# The Groq API key is read from the Groq_TOKEN environment variable
# (this app's choice of name; the SDK's own default is GROQ_API_KEY).
groq_token = os.getenv("Groq_TOKEN", None)
client = Groq(api_key=groq_token)

@gpu_decorator
def generate_response(messages):
    """Generate a chat response via the Groq API (non-streaming)."""
    chat_completion = client.chat.completions.create(
        # Accept either a bare user string or a full message list.
        messages=[{"role": "user", "content": messages}] if isinstance(messages, str) else messages,
        model="llama-3.3-70b-versatile",
        stream=False,
    )
    return chat_completion.choices[0].message.content


@gpu_decorator
def process_audio_input(audio_path, text, history, conv_state):
    """Handle one user turn: transcribe recorded audio if present, query the LLM, update history."""
    if not audio_path and not text.strip():
        return history, conv_state

    if audio_path:
        # preprocess_ref_audio_text returns (processed_audio, transcript);
        # use the transcript of the recorded message as the user text.
        text = preprocess_ref_audio_text(audio_path, text)[1]

    if not text.strip():
        return history, conv_state

    conv_state.append({"role": "user", "content": text})
    history.append((text, None))
    response = generate_response(conv_state)
    conv_state.append({"role": "assistant", "content": response})
    history[-1] = (text, response)
    return history, conv_state


@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    remove_silence,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1,
    show_info=gr.Info,
):
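    """Synthesize gen_text in the voice of ref_audio_orig with F5-TTS.

    nfe_step is the number of flow-matching denoising steps (quality vs. speed
    trade-off); cross_fade_duration (seconds) smooths joins between generated chunks.
    Returns ((sample_rate, waveform), spectrogram_path, ref_text).
    """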
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text

    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
    ema_model = F5TTS_ema_model  # Use F5-TTS by default

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    if remove_silence:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(f.name)
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
        save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text


# Single source of truth for the role-play system prompt reused across the UI.
DEFAULT_SYSTEM_PROMPT = (
    "You are not an AI assistant, you are whoever the user says you are. "
    "You must stay in character. Keep your responses concise since they will be spoken out loud."
)

with gr.Blocks() as app_chat:
    gr.Markdown("""
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Load the chat model (local runs only; it is preloaded on Spaces).
3. Record your message through your microphone or type it.
4. The AI will respond using the reference voice.
""")

    if not USING_SPACES:
        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
        chat_interface_container = gr.Column(visible=False)

        @gpu_decorator
        def load_chat_model():
            global chat_model_state, chat_tokenizer_state
            if chat_model_state is None:
                gr.Info("Loading chat model...")
                model_name = "deepseek-ai/Janus-Pro-7B"
                chat_model_state = AutoModelForCausalLM.from_pretrained(
                    model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
                )
                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                gr.Info("Chat model loaded.")
            # Hide the load button and reveal the chat interface.
            return gr.update(visible=False), gr.update(visible=True)

        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
    else:
        # On Spaces the model is loaded eagerly and the chat interface is always visible.
        chat_interface_container = gr.Column()
        model_name = "deepseek-ai/Janus-Pro-7B"
        chat_model_state = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    with chat_interface_container:
        with gr.Row():
            with gr.Column():
                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=False):
                    remove_silence_chat = gr.Checkbox(
                        label="Remove Silences",
                        value=True,
                    )
                    ref_text_chat = gr.Textbox(
                        label="Reference Text",
                        info="Optional: Leave blank to auto-transcribe",
                        lines=2,
                    )
                    system_prompt_chat = gr.Textbox(
                        label="System Prompt",
                        value=DEFAULT_SYSTEM_PROMPT,
                        lines=2,
                    )

        chatbot_interface = gr.Chatbot(label="Conversation")
        with gr.Row():
            with gr.Column():
                audio_input_chat = gr.Microphone(
                    label="Speak your message",
                    type="filepath",
                )
                audio_output_chat = gr.Audio(autoplay=True)
            with gr.Column():
                text_input_chat = gr.Textbox(
                    label="Type your message",
                    lines=1,
                )
                send_btn_chat = gr.Button("Send Message")
                clear_btn_chat = gr.Button("Clear Conversation")

        conversation_state = gr.State(
            value=[{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
        )

        @gpu_decorator
        def generate_audio_response(history, ref_audio, ref_text, remove_silence):
            """Synthesize the most recent AI reply in the reference voice."""
            if not history or not ref_audio:
                return None, ref_text

            _, last_ai_response = history[-1]
            if not last_ai_response:
                return None, ref_text

            audio_result, _, ref_text_out = infer(
                ref_audio,
                ref_text,
                last_ai_response,
                tts_model_choice,
                remove_silence,
                cross_fade_duration=0.15,
                speed=1.0,
                show_info=print,  # log to the console instead of popping UI toasts
            )
            # Pass the (possibly auto-transcribed) reference text back to the UI.
            return audio_result, ref_text_out

        def clear_conversation():
            """Reset both the chat display and the LLM conversation state."""
            return [], [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]

        def update_system_prompt(new_prompt):
            """Changing the system prompt clears the conversation and restarts with the new prompt."""
            return [], [{"role": "system", "content": new_prompt}]

        # Event wiring: each input path (mic recording, Enter in the textbox,
        # Send button) runs the LLM turn, then synthesizes the reply audio,
        # then clears its own input component.
        audio_input_chat.stop_recording(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
            outputs=[audio_output_chat, ref_text_chat],
        ).then(lambda: None, None, audio_input_chat)

        text_input_chat.submit(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
            outputs=[audio_output_chat, ref_text_chat],
        ).then(lambda: None, None, text_input_chat)

        send_btn_chat.click(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
            outputs=[audio_output_chat, ref_text_chat],
        ).then(lambda: None, None, text_input_chat)

        clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])
        system_prompt_chat.change(update_system_prompt, inputs=system_prompt_chat, outputs=[chatbot_interface, conversation_state])


app = app_chat


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
@click.option("--api/--no-api", default=True, help="Allow API access")
@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application, useful when running behind a reverse proxy")
def main(port, host, share, api, root_path):
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
        root_path=root_path,
    )


if __name__ == "__main__":
    main()
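

# Example local launch (script name is an assumption; adjust to this file's name):
#   python gradio_voice_chat.py --port 7860 --host 0.0.0.0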