NeoPy committed on
Commit c533b68 · verified · 1 Parent(s): 6cb3316

Update app.py

Files changed (1)
  1. app.py +117 -68
app.py CHANGED
@@ -1,18 +1,23 @@
+#!/usr/bin/env python
 # ruff: noqa: E402
+
 import json
 import re
 import tempfile
-from importlib.resources import files
-from groq import Groq
 import os
+
 import click
 import gradio as gr
 import numpy as np
 import soundfile as sf
 import torchaudio
+
+from importlib.resources import files
+from groq import Groq
 from cached_path import cached_path
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# Try to import spaces; if available, set USING_SPACES to True so we can decorate functions for GPU support.
 try:
     import spaces
 
@@ -22,12 +27,15 @@ except ImportError:
 
 
 def gpu_decorator(func):
+    """
+    Decorator that wraps a function with GPU acceleration if running in a Spaces environment.
+    """
     if USING_SPACES:
         return spaces.GPU(func)
-    else:
-        return func
+    return func
 
 
+# Local package imports
 from f5_tts.model import DiT, UNetT
 from f5_tts.infer.utils_infer import (
     load_vocoder,
@@ -38,58 +46,70 @@ from f5_tts.infer.utils_infer import (
     save_spectrogram,
 )
 
-
 DEFAULT_TTS_MODEL = "F5-TTS"
-tts_model_choice = DEFAULT_TTS_MODEL
-
 DEFAULT_TTS_MODEL_CFG = [
     "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
     "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
     json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
 ]
 
-
-# Load models
+# Load vocoder and TTS model
 vocoder = load_vocoder()
 
-def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
+
+def load_f5tts(
+    ckpt_path: str = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+):
+    """
+    Load the F5-TTS model from the given checkpoint path.
+    """
     F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
     return load_model(DiT, F5TTS_model_cfg, ckpt_path)
 
-F5TTS_ema_model = load_f5tts()
-chat_model_state = None
-chat_tokenizer_state = None
 
+F5TTS_ema_model = load_f5tts()
 
 
+# Setup the Groq client for chat completions.
 groq_token = os.getenv("Groq_TOKEN", None)
-client = Groq(
-    api_key=groq_token,
-)
+client = Groq(api_key=groq_token)
+
 
 @gpu_decorator
 def generate_response(messages):
-    """Generate response using Groq"""
+    """
+    Generate a chat response using the Groq API.
+    If messages is a string, wrap it as a user message.
+    """
+    if isinstance(messages, str):
+        messages_payload = [{"role": "user", "content": messages}]
+    else:
+        messages_payload = messages
+
     chat_completion = client.chat.completions.create(
-        messages=[
-            {
-                "role": "user",
-                "content": messages,
-            }
-        ] if isinstance(messages, str) else messages,
+        messages=messages_payload,
         model="llama-3.3-70b-versatile",
         stream=False,
     )
-    return chat_completion.choices[0].message.content  # this may need to be fixed
+    # Check that we got a valid response.
+    if chat_completion.choices and hasattr(chat_completion.choices[0].message, "content"):
+        return chat_completion.choices[0].message.content
+    return ""
 
 
 @gpu_decorator
 def process_audio_input(audio_path, text, history, conv_state):
+    """
+    Process audio and/or text input from the user:
+    - If an audio file is provided, its transcript is obtained.
+    - The conversation state and history are updated.
+    """
     if not audio_path and not text.strip():
         return history, conv_state, ""
 
     if audio_path:
-        text = preprocess_ref_audio_text(audio_path, text)[1]
+        # preprocess_ref_audio_text returns a tuple (audio, transcript).
+        _, text = preprocess_ref_audio_text(audio_path, text)
 
     if not text.strip():
         return history, conv_state, ""
@@ -102,19 +122,20 @@ def process_audio_input(audio_path, text, history, conv_state):
     return history, conv_state, ""
 
 
-
 @gpu_decorator
 def infer(
     ref_audio_orig,
     ref_text,
     gen_text,
-    model,
     remove_silence,
-    cross_fade_duration=0.15,
-    nfe_step=32,
-    speed=1,
-    show_info=gr.Info,
+    cross_fade_duration: float = 0.15,
+    nfe_step: int = 32,
+    speed: float = 1,
+    show_info=print,
 ):
+    """
+    Generate speech audio using the F5-TTS system based on a reference audio/text and generated text.
+    """
     if not ref_audio_orig:
         gr.Warning("Please provide reference audio.")
         return gr.update(), gr.update(), ref_text
@@ -123,8 +144,9 @@ def infer(
         gr.Warning("Please enter text to generate.")
         return gr.update(), gr.update(), ref_text
 
+    # Preprocess the reference audio and text.
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
-    ema_model = F5TTS_ema_model  # Use F5-TTS by default
+    ema_model = F5TTS_ema_model  # Use the default F5-TTS model.
 
     final_wave, final_sample_rate, combined_spectrogram = infer_process(
         ref_audio,
@@ -140,12 +162,17 @@ def infer(
     )
 
     if remove_silence:
+        # Write the generated waveform to a temporary file.
         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-            sf.write(f.name, final_wave, final_sample_rate)
-            remove_silence_for_generated_wav(f.name)
-            final_wave, _ = torchaudio.load(f.name)
-            final_wave = final_wave.squeeze().cpu().numpy()
-
+            temp_audio_path = f.name
+            sf.write(temp_audio_path, final_wave, final_sample_rate)
+        # Process the file to remove silence.
+        remove_silence_for_generated_wav(temp_audio_path)
+        final_wave_tensor, _ = torchaudio.load(temp_audio_path)
+        final_wave = final_wave_tensor.squeeze().cpu().numpy()
+        os.unlink(temp_audio_path)  # Clean up the temporary file.
+
+    # Save the spectrogram as a temporary PNG file.
     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
         spectrogram_path = tmp_spectrogram.name
         save_spectrogram(combined_spectrogram, spectrogram_path)
@@ -154,29 +181,27 @@
 
 
 with gr.Blocks() as app_chat:
-    gr.Markdown("""
+    gr.Markdown(
+        """
 # Voice Chat
 Have a conversation with an AI using your reference voice!
 1. Upload a reference audio clip and optionally its transcript.
 2. Load the chat model.
 3. Record your message through your microphone.
 4. The AI will respond using the reference voice.
-    """)
+    """
+    )
 
-    if not USING_SPACES:
-        chat_interface_container = gr.Column(visible=False)
+    # The chat interface container is visible only if running in a Spaces environment.
+    chat_interface_container = gr.Column(visible=USING_SPACES)
 
-
     with chat_interface_container:
         with gr.Row():
             with gr.Column():
                 ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
             with gr.Column():
                 with gr.Accordion("Advanced Settings", open=False):
-                    remove_silence_chat = gr.Checkbox(
-                        label="Remove Silences",
-                        value=True,
-                    )
+                    remove_silence_chat = gr.Checkbox(label="Remove Silences", value=True)
                     ref_text_chat = gr.Textbox(
                         label="Reference Text",
                         info="Optional: Leave blank to auto-transcribe",
@@ -184,52 +209,52 @@ Have a conversation with an AI using your reference voice!
                     )
                 system_prompt_chat = gr.Textbox(
                     label="System Prompt",
-                    value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                    value=(
+                        "You are not an AI assistant, you are whoever the user says you are. "
+                        "You must stay in character. Keep your responses concise since they will be spoken out loud."
+                    ),
                     lines=2,
                 )
 
         chatbot_interface = gr.Chatbot(label="Conversation")
         with gr.Row():
             with gr.Column():
-                audio_input_chat = gr.Microphone(
-                    label="Speak your message",
-                    type="filepath",
-                )
+                audio_input_chat = gr.Microphone(label="Speak your message", type="filepath")
                 audio_output_chat = gr.Audio(autoplay=True)
             with gr.Column():
-                text_input_chat = gr.Textbox(
-                    label="Type your message",
-                    lines=1,
-                )
+                text_input_chat = gr.Textbox(label="Type your message", lines=1)
                 send_btn_chat = gr.Button("Send Message")
                 clear_btn_chat = gr.Button("Clear Conversation")
 
+        # Initialize the conversation state with the system prompt.
         conversation_state = gr.State(
             value=[
                 {
                     "role": "system",
-                    "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                    "content": (
+                        "You are not an AI assistant, you are whoever the user says you are. "
+                        "You must stay in character. Keep your responses concise since they will be spoken out loud."
+                    ),
                 }
             ]
         )
-
-
-
 
     @gpu_decorator
     def generate_audio_response(history, ref_audio, ref_text, remove_silence):
+        """
+        Generate an audio response from the last AI message in the conversation.
+        """
         if not history or not ref_audio:
-            return None
+            return None, ref_text
 
         last_user_message, last_ai_response = history[-1]
         if not last_ai_response:
-            return None
+            return None, ref_text
 
         audio_result, _, ref_text_out = infer(
             ref_audio,
             ref_text,
             last_ai_response,
-            tts_model_choice,
             remove_silence,
             cross_fade_duration=0.15,
             speed=1.0,
@@ -238,11 +263,28 @@ Have a conversation with an AI using your reference voice!
         return audio_result, ref_text_out
 
     def clear_conversation():
-        return [], [{"role": "system", "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud."}]
+        """
+        Clear the chat conversation and reset the conversation state.
+        """
+        initial_state = [
+            {
+                "role": "system",
+                "content": (
+                    "You are not an AI assistant, you are whoever the user says you are. "
+                    "You must stay in character. Keep your responses concise since they will be spoken out loud."
+                ),
+            }
+        ]
+        return [], initial_state
 
     def update_system_prompt(new_prompt):
-        return [], [{"role": "system", "content": new_prompt}]
+        """
+        Update the system prompt and reset the conversation.
+        """
+        initial_state = [{"role": "system", "content": new_prompt}]
+        return [], initial_state
 
+    # Set up callbacks so that when recording stops, or text is submitted, the chain of processing is run.
     audio_input_chat.stop_recording(
         process_audio_input,
         inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
@@ -274,7 +316,11 @@ Have a conversation with an AI using your reference voice!
     ).then(lambda: None, None, text_input_chat)
 
     clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])
-    system_prompt_chat.change(update_system_prompt, inputs=system_prompt_chat, outputs=[chatbot_interface, conversation_state])
+    system_prompt_chat.change(
+        update_system_prompt,
+        inputs=system_prompt_chat,
+        outputs=[chatbot_interface, conversation_state],
+    )
 
 
 app = app_chat
@@ -285,16 +331,19 @@ app = app_chat
 @click.option("--host", "-H", default=None, help="Host to run the app on")
 @click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
 @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
-@click.option("--root_path", "-r", default=None, type=str, help='Root path for the application')
+@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application")
 def main(port, host, share, api, root_path):
+    """
+    Launch the Gradio app.
+    """
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
-        root_path=root_path
+        root_path=root_path,
    )
 
 
 if __name__ == "__main__":
-    main()
+    main()
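The new `generate_response` both normalizes its input (wrapping a bare string as a single user message) and checks the response before indexing into it. A minimal standalone sketch of that Groq chat-completion pattern; the `ask` helper and its prompt are illustrative, not part of the commit, and it assumes the `groq` package is installed and `Groq_TOKEN` is set as in the diff:

```python
import os

from groq import Groq

# Client built from the same environment variable the diff reads.
client = Groq(api_key=os.getenv("Groq_TOKEN"))


def ask(prompt: str) -> str:
    # Wrap the bare string as a single user message, as generate_response() now does.
    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
        stream=False,
    )
    # Guard against an empty choices list before reading the reply.
    if completion.choices:
        return completion.choices[0].message.content
    return ""


if __name__ == "__main__":
    print(ask("Say hello in one short sentence."))
```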
 
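The reworked `remove_silence` block also changes when the temporary WAV is touched: the old code reopened and processed the file while the `NamedTemporaryFile` handle was still open (which can fail on platforms that lock open files) and never deleted it; the new code writes inside the `with` block, processes only after the handle is closed, and unlinks the file at the end. The same round-trip pattern in isolation, with a hypothetical `halve_volume` standing in for `remove_silence_for_generated_wav`:

```python
import os
import tempfile

import numpy as np
import soundfile as sf


def halve_volume(path: str) -> None:
    # Stand-in for remove_silence_for_generated_wav: any in-place file processor fits here.
    wave, sr = sf.read(path)
    sf.write(path, wave * 0.5, sr)


def process_via_tempfile(wave: np.ndarray, sample_rate: int) -> np.ndarray:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        temp_path = f.name
        sf.write(temp_path, wave, sample_rate)
    # The handle is closed here, so reopening the file by name is safe everywhere.
    halve_volume(temp_path)
    processed, _ = sf.read(temp_path)
    os.unlink(temp_path)  # Clean up the temporary file.
    return processed


if __name__ == "__main__":
    tone = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 24000)).astype(np.float32)
    print(process_via_tempfile(tone, 24000).shape)
```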
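The callback wiring in the diff leans on two Gradio idioms: `gr.State` for the conversation history and event chaining with `.then(...)` to clear the input box after a message is processed. A self-contained sketch of that pattern, independent of the TTS and Groq pieces (all component names here are illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    # Conversation state, playing the role of conversation_state in the diff.
    history_state = gr.State(value=[])
    msg_box = gr.Textbox(label="Type your message", lines=1)
    log_box = gr.Textbox(label="Conversation", interactive=False)
    send_btn = gr.Button("Send Message")

    def record_message(message, history):
        # Append the new message to the state and render the running log.
        history = history + [message]
        return history, "\n".join(history)

    # Run record_message first, then clear the input box, mirroring the
    # send_btn_chat.click(...).then(lambda: None, None, text_input_chat) chain above.
    send_btn.click(
        record_message,
        inputs=[msg_box, history_state],
        outputs=[history_state, log_box],
    ).then(lambda: "", None, msg_box)

demo.launch()
```

With the options shown in the final hunk, the app itself launches via `python app.py` (optionally with `--share` or `--root_path`).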