Update app.py
app.py
CHANGED
@@ -1,18 +1,23 @@
+#!/usr/bin/env python
 # ruff: noqa: E402
+
 import json
 import re
 import tempfile
-from importlib.resources import files
-from groq import Groq
 import os
+
 import click
 import gradio as gr
 import numpy as np
 import soundfile as sf
 import torchaudio
+
+from importlib.resources import files
+from groq import Groq
 from cached_path import cached_path
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# Try to import spaces; if available, set USING_SPACES to True so we can decorate functions for GPU support.
 try:
     import spaces
 
@@ -22,12 +27,15 @@ except ImportError:
 
 
 def gpu_decorator(func):
+    """
+    Decorator that wraps a function with GPU acceleration if running in a Spaces environment.
+    """
     if USING_SPACES:
         return spaces.GPU(func)
-    else:
-        return func
+    return func
 
 
+# Local package imports
 from f5_tts.model import DiT, UNetT
 from f5_tts.infer.utils_infer import (
     load_vocoder,
@@ -38,58 +46,70 @@ from f5_tts.infer.utils_infer import (
     save_spectrogram,
 )
 
-
 DEFAULT_TTS_MODEL = "F5-TTS"
-tts_model_choice = DEFAULT_TTS_MODEL
-
 DEFAULT_TTS_MODEL_CFG = [
     "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
     "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
     json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
 ]
 
-
-# Load models
+# Load vocoder and TTS model
 vocoder = load_vocoder()
 
-def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
+
+def load_f5tts(
+    ckpt_path: str = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+):
+    """
+    Load the F5-TTS model from the given checkpoint path.
+    """
     F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
     return load_model(DiT, F5TTS_model_cfg, ckpt_path)
 
-F5TTS_ema_model = load_f5tts()
-chat_model_state = None
-chat_tokenizer_state = None
 
+F5TTS_ema_model = load_f5tts()
 
 
+# Setup the Groq client for chat completions.
 groq_token = os.getenv("Groq_TOKEN", None)
-client = Groq(
-    api_key=groq_token,
-)
+client = Groq(api_key=groq_token)
+
 
 @gpu_decorator
 def generate_response(messages):
-    """
+    """
+    Generate a chat response using the Groq API.
+    If messages is a string, wrap it as a user message.
+    """
+    if isinstance(messages, str):
+        messages_payload = [{"role": "user", "content": messages}]
+    else:
+        messages_payload = messages
+
     chat_completion = client.chat.completions.create(
-        messages=[
-            {
-                "role": "user",
-                "content": messages,
-            }
-        ] if isinstance(messages, str) else messages,
+        messages=messages_payload,
        model="llama-3.3-70b-versatile",
         stream=False,
     )
-    return chat_completion.choices[0].message.content
+    # Check that we got a valid response.
+    if chat_completion.choices and hasattr(chat_completion.choices[0].message, "content"):
+        return chat_completion.choices[0].message.content
+    return ""
 
 
 @gpu_decorator
 def process_audio_input(audio_path, text, history, conv_state):
+    """
+    Process audio and/or text input from the user:
+    - If an audio file is provided, its transcript is obtained.
+    - The conversation state and history are updated.
+    """
     if not audio_path and not text.strip():
         return history, conv_state, ""
 
     if audio_path:
-        text = preprocess_ref_audio_text(audio_path, text)[1]
+        # preprocess_ref_audio_text returns a tuple (audio, transcript).
+        _, text = preprocess_ref_audio_text(audio_path, text)
 
     if not text.strip():
         return history, conv_state, ""
@@ -102,19 +122,20 @@ def process_audio_input(audio_path, text, history, conv_state):
     return history, conv_state, ""
 
 
-
 @gpu_decorator
 def infer(
     ref_audio_orig,
     ref_text,
     gen_text,
-    model,
     remove_silence,
-    cross_fade_duration=0.15,
-    nfe_step=32,
-    speed=1,
-    show_info=gr.Info,
+    cross_fade_duration: float = 0.15,
+    nfe_step: int = 32,
+    speed: float = 1,
+    show_info=print,
 ):
+    """
+    Generate speech audio using the F5-TTS system based on a reference audio/text and generated text.
+    """
     if not ref_audio_orig:
         gr.Warning("Please provide reference audio.")
         return gr.update(), gr.update(), ref_text
@@ -123,8 +144,9 @@ def infer(
         gr.Warning("Please enter text to generate.")
         return gr.update(), gr.update(), ref_text
 
+    # Preprocess the reference audio and text.
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
-    ema_model = F5TTS_ema_model  # Use F5-TTS
+    ema_model = F5TTS_ema_model  # Use the default F5-TTS model.
 
     final_wave, final_sample_rate, combined_spectrogram = infer_process(
         ref_audio,
@@ -140,12 +162,17 @@ def infer(
     )
 
     if remove_silence:
+        # Write the generated waveform to a temporary file.
         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-            sf.write(f.name, final_wave, final_sample_rate)
-            remove_silence_for_generated_wav(f.name)
-            final_wave, _ = torchaudio.load(f.name)
-        final_wave = final_wave.squeeze().cpu().numpy()
-
+            temp_audio_path = f.name
+            sf.write(temp_audio_path, final_wave, final_sample_rate)
+        # Process the file to remove silence.
+        remove_silence_for_generated_wav(temp_audio_path)
+        final_wave_tensor, _ = torchaudio.load(temp_audio_path)
+        final_wave = final_wave_tensor.squeeze().cpu().numpy()
+        os.unlink(temp_audio_path)  # Clean up the temporary file.
+
+    # Save the spectrogram as a temporary PNG file.
     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
         spectrogram_path = tmp_spectrogram.name
         save_spectrogram(combined_spectrogram, spectrogram_path)
@@ -154,29 +181,27 @@ def infer(
 
 
 with gr.Blocks() as app_chat:
-    gr.Markdown("""
+    gr.Markdown(
+        """
 # Voice Chat
 Have a conversation with an AI using your reference voice!
 1. Upload a reference audio clip and optionally its transcript.
 2. Load the chat model.
 3. Record your message through your microphone.
 4. The AI will respond using the reference voice.
-""")
+"""
+    )
 
-    if not USING_SPACES:
-        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
+    # The chat interface container is visible only if running in a Spaces environment.
+    chat_interface_container = gr.Column(visible=USING_SPACES)
 
-    chat_interface_container = gr.Column(visible=False)
     with chat_interface_container:
         with gr.Row():
             with gr.Column():
                 ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
             with gr.Column():
                 with gr.Accordion("Advanced Settings", open=False):
-                    remove_silence_chat = gr.Checkbox(
-                        label="Remove Silences",
-                        value=True,
-                    )
+                    remove_silence_chat = gr.Checkbox(label="Remove Silences", value=True)
                     ref_text_chat = gr.Textbox(
                         label="Reference Text",
                         info="Optional: Leave blank to auto-transcribe",
@@ -184,52 +209,52 @@ Have a conversation with an AI using your reference voice!
                     )
                     system_prompt_chat = gr.Textbox(
                         label="System Prompt",
-                        value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                        value=(
+                            "You are not an AI assistant, you are whoever the user says you are. "
+                            "You must stay in character. Keep your responses concise since they will be spoken out loud."
+                        ),
                         lines=2,
                     )
 
         chatbot_interface = gr.Chatbot(label="Conversation")
         with gr.Row():
             with gr.Column():
-                audio_input_chat = gr.Microphone(
-                    label="Speak your message",
-                    type="filepath",
-                )
+                audio_input_chat = gr.Microphone(label="Speak your message", type="filepath")
                 audio_output_chat = gr.Audio(autoplay=True)
             with gr.Column():
-                text_input_chat = gr.Textbox(
-                    label="Type your message",
-                    lines=1,
-                )
+                text_input_chat = gr.Textbox(label="Type your message", lines=1)
                 send_btn_chat = gr.Button("Send Message")
                 clear_btn_chat = gr.Button("Clear Conversation")
 
+        # Initialize the conversation state with the system prompt.
         conversation_state = gr.State(
             value=[
                 {
                     "role": "system",
-                    "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                    "content": (
+                        "You are not an AI assistant, you are whoever the user says you are. "
+                        "You must stay in character. Keep your responses concise since they will be spoken out loud."
+                    ),
                 }
             ]
         )
-
-
-
 
         @gpu_decorator
         def generate_audio_response(history, ref_audio, ref_text, remove_silence):
+            """
+            Generate an audio response from the last AI message in the conversation.
+            """
            if not history or not ref_audio:
-                return None
+                return None, ref_text
 
             last_user_message, last_ai_response = history[-1]
             if not last_ai_response:
-                return None
+                return None, ref_text
 
             audio_result, _, ref_text_out = infer(
                 ref_audio,
                 ref_text,
                 last_ai_response,
-                tts_model_choice,
                 remove_silence,
                 cross_fade_duration=0.15,
                 speed=1.0,
@@ -238,11 +263,28 @@ Have a conversation with an AI using your reference voice!
             return audio_result, ref_text_out
 
         def clear_conversation():
-            return [], [{"role": "system", "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud."}]
+            """
+            Clear the chat conversation and reset the conversation state.
+            """
+            initial_state = [
+                {
+                    "role": "system",
+                    "content": (
+                        "You are not an AI assistant, you are whoever the user says you are. "
+                        "You must stay in character. Keep your responses concise since they will be spoken out loud."
+                    ),
+                }
+            ]
+            return [], initial_state
 
         def update_system_prompt(new_prompt):
-            return [], [{"role": "system", "content": new_prompt}]
+            """
+            Update the system prompt and reset the conversation.
+            """
+            initial_state = [{"role": "system", "content": new_prompt}]
+            return [], initial_state
 
+        # Set up callbacks so that when recording stops, or text is submitted, the chain of processing is run.
         audio_input_chat.stop_recording(
            process_audio_input,
             inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
@@ -274,7 +316,11 @@ Have a conversation with an AI using your reference voice!
         ).then(lambda: None, None, text_input_chat)
 
         clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])
-        system_prompt_chat.change(update_system_prompt, inputs=system_prompt_chat, outputs=[chatbot_interface, conversation_state])
+        system_prompt_chat.change(
+            update_system_prompt,
+            inputs=system_prompt_chat,
+            outputs=[chatbot_interface, conversation_state],
+        )
 
 
 app = app_chat
@@ -285,16 +331,19 @@ app = app_chat
 @click.option("--host", "-H", default=None, help="Host to run the app on")
 @click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
 @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
-@click.option("--root_path", "-r", default=None, type=str, help=
+@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application")
 def main(port, host, share, api, root_path):
+    """
+    Launch the Gradio app.
+    """
     app.queue(api_open=api).launch(
         server_name=host,
         server_port=port,
         share=share,
         show_api=api,
-        root_path=root_path
+        root_path=root_path,
     )
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
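
Review note: the updated code reads the Groq API key from the Groq_TOKEN environment variable, and a missing or invalid key only surfaces when the first completion is requested. Below is a minimal smoke test for the token, a sketch that assumes the groq SDK is installed and Groq_TOKEN is exported; the model name is the one pinned in this diff.

import os

from groq import Groq

# Assumes Groq_TOKEN is set in the environment, matching app.py.
client = Groq(api_key=os.getenv("Groq_TOKEN"))

# One-shot, non-streaming completion against the model pinned in the diff.
completion = client.chat.completions.create(
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    model="llama-3.3-70b-versatile",
    stream=False,
)
print(completion.choices[0].message.content)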
+
# Preprocess the reference audio and text.
|
148 |
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
|
149 |
+
ema_model = F5TTS_ema_model # Use the default F5-TTS model.
|
150 |
|
151 |
final_wave, final_sample_rate, combined_spectrogram = infer_process(
|
152 |
ref_audio,
|
|
|
162 |
)
|
163 |
|
164 |
if remove_silence:
|
165 |
+
# Write the generated waveform to a temporary file.
|
166 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
167 |
+
temp_audio_path = f.name
|
168 |
+
sf.write(temp_audio_path, final_wave, final_sample_rate)
|
169 |
+
# Process the file to remove silence.
|
170 |
+
remove_silence_for_generated_wav(temp_audio_path)
|
171 |
+
final_wave_tensor, _ = torchaudio.load(temp_audio_path)
|
172 |
+
final_wave = final_wave_tensor.squeeze().cpu().numpy()
|
173 |
+
os.unlink(temp_audio_path) # Clean up the temporary file.
|
174 |
+
|
175 |
+
# Save the spectrogram as a temporary PNG file.
|
176 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
|
177 |
spectrogram_path = tmp_spectrogram.name
|
178 |
save_spectrogram(combined_spectrogram, spectrogram_path)
|
|
|
181 |
|
182 |
|
183 |
with gr.Blocks() as app_chat:
|
184 |
+
gr.Markdown(
|
185 |
+
"""
|
186 |
# Voice Chat
|
187 |
Have a conversation with an AI using your reference voice!
|
188 |
1. Upload a reference audio clip and optionally its transcript.
|
189 |
2. Load the chat model.
|
190 |
3. Record your message through your microphone.
|
191 |
4. The AI will respond using the reference voice.
|
192 |
+
"""
|
193 |
+
)
|
194 |
|
195 |
+
# The chat interface container is visible only if running in a Spaces environment.
|
196 |
+
chat_interface_container = gr.Column(visible=USING_SPACES)
|
197 |
|
|
|
198 |
with chat_interface_container:
|
199 |
with gr.Row():
|
200 |
with gr.Column():
|
201 |
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
|
202 |
with gr.Column():
|
203 |
with gr.Accordion("Advanced Settings", open=False):
|
204 |
+
remove_silence_chat = gr.Checkbox(label="Remove Silences", value=True)
|
|
|
|
|
|
|
205 |
ref_text_chat = gr.Textbox(
|
206 |
label="Reference Text",
|
207 |
info="Optional: Leave blank to auto-transcribe",
|
|
|
209 |
)
|
210 |
system_prompt_chat = gr.Textbox(
|
211 |
label="System Prompt",
|
212 |
+
value=(
|
213 |
+
"You are not an AI assistant, you are whoever the user says you are. "
|
214 |
+
"You must stay in character. Keep your responses concise since they will be spoken out loud."
|
215 |
+
),
|
216 |
lines=2,
|
217 |
)
|
218 |
|
219 |
chatbot_interface = gr.Chatbot(label="Conversation")
|
220 |
with gr.Row():
|
221 |
with gr.Column():
|
222 |
+
audio_input_chat = gr.Microphone(label="Speak your message", type="filepath")
|
|
|
|
|
|
|
223 |
audio_output_chat = gr.Audio(autoplay=True)
|
224 |
with gr.Column():
|
225 |
+
text_input_chat = gr.Textbox(label="Type your message", lines=1)
|
|
|
|
|
|
|
226 |
send_btn_chat = gr.Button("Send Message")
|
227 |
clear_btn_chat = gr.Button("Clear Conversation")
|
228 |
|
229 |
+
# Initialize the conversation state with the system prompt.
|
230 |
conversation_state = gr.State(
|
231 |
value=[
|
232 |
{
|
233 |
"role": "system",
|
234 |
+
"content": (
|
235 |
+
"You are not an AI assistant, you are whoever the user says you are. "
|
236 |
+
"You must stay in character. Keep your responses concise since they will be spoken out loud."
|
237 |
+
),
|
238 |
}
|
239 |
]
|
240 |
)
|
|
|
|
|
|
|
241 |
|
242 |
@gpu_decorator
|
243 |
def generate_audio_response(history, ref_audio, ref_text, remove_silence):
|
244 |
+
"""
|
245 |
+
Generate an audio response from the last AI message in the conversation.
|
246 |
+
"""
|
247 |
if not history or not ref_audio:
|
248 |
+
return None, ref_text
|
249 |
|
250 |
last_user_message, last_ai_response = history[-1]
|
251 |
if not last_ai_response:
|
252 |
+
return None, ref_text
|
253 |
|
254 |
audio_result, _, ref_text_out = infer(
|
255 |
ref_audio,
|
256 |
ref_text,
|
257 |
last_ai_response,
|
|
|
258 |
remove_silence,
|
259 |
cross_fade_duration=0.15,
|
260 |
speed=1.0,
|
|
|
263 |
return audio_result, ref_text_out
|
264 |
|
265 |
def clear_conversation():
|
266 |
+
"""
|
267 |
+
Clear the chat conversation and reset the conversation state.
|
268 |
+
"""
|
269 |
+
initial_state = [
|
270 |
+
{
|
271 |
+
"role": "system",
|
272 |
+
"content": (
|
273 |
+
"You are not an AI assistant, you are whoever the user says you are. "
|
274 |
+
"You must stay in character. Keep your responses concise since they will be spoken out loud."
|
275 |
+
),
|
276 |
+
}
|
277 |
+
]
|
278 |
+
return [], initial_state
|
279 |
|
280 |
def update_system_prompt(new_prompt):
|
281 |
+
"""
|
282 |
+
Update the system prompt and reset the conversation.
|
283 |
+
"""
|
284 |
+
initial_state = [{"role": "system", "content": new_prompt}]
|
285 |
+
return [], initial_state
|
286 |
|
287 |
+
# Set up callbacks so that when recording stops, or text is submitted, the chain of processing is run.
|
288 |
audio_input_chat.stop_recording(
|
289 |
process_audio_input,
|
290 |
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
|
|
|
316 |
).then(lambda: None, None, text_input_chat)
|
317 |
|
318 |
clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])
|
319 |
+
system_prompt_chat.change(
|
320 |
+
update_system_prompt,
|
321 |
+
inputs=system_prompt_chat,
|
322 |
+
outputs=[chatbot_interface, conversation_state],
|
323 |
+
)
|
324 |
|
325 |
|
326 |
app = app_chat
|
|
|
331 |
@click.option("--host", "-H", default=None, help="Host to run the app on")
|
332 |
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
|
333 |
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
|
334 |
+
@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application")
|
335 |
def main(port, host, share, api, root_path):
|
336 |
+
"""
|
337 |
+
Launch the Gradio app.
|
338 |
+
"""
|
339 |
app.queue(api_open=api).launch(
|
340 |
server_name=host,
|
341 |
server_port=port,
|
342 |
share=share,
|
343 |
show_api=api,
|
344 |
+
root_path=root_path,
|
345 |
)
|
346 |
|
347 |
|
348 |
if __name__ == "__main__":
|
349 |
+
main()
|
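Review note: the new --root_path option is forwarded to launch(), which matters when the Space runs behind a reverse proxy that serves it under a path prefix. The sketch below is the programmatic equivalent of the CLI path, using the same launch() arguments main() passes; "/f5tts" is a hypothetical prefix, and the --port/-p option is elided from this diff but implied by main(port, ...).

# Roughly equivalent to: python app.py --root_path /f5tts  (default host/port assumed)
app.queue(api_open=True).launch(
    server_name=None,    # default host
    server_port=None,    # default port
    share=False,
    show_api=True,
    root_path="/f5tts",  # hypothetical reverse-proxy prefix
)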