NeoPy commited on
Commit
0ef0349
·
verified ·
1 Parent(s): 7f72ad0

Delete gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +0 -824
gradio_app.py DELETED
@@ -1,824 +0,0 @@
1
- import os
2
- import re
3
- import torch
4
- import torchaudio
5
- import gradio as gr
6
- import numpy as np
7
- import tempfile
8
- from einops import rearrange
9
- from vocos import Vocos
10
- from pydub import AudioSegment, silence
11
- from model import CFM, UNetT, DiT, MMDiT
12
- from cached_path import cached_path
13
- from model.utils import (
14
- load_checkpoint,
15
- get_tokenizer,
16
- convert_char_to_pinyin,
17
- save_spectrogram,
18
- )
19
- from transformers import pipeline
20
- import librosa
21
- import click
22
- import soundfile as sf
23
-
24
- try:
25
- import spaces
26
- USING_SPACES = True
27
- except ImportError:
28
- USING_SPACES = False
29
-
30
- def gpu_decorator(func):
31
- if USING_SPACES:
32
- return spaces.GPU(func)
33
- else:
34
- return func
35
-
36
-
37
-
38
- SPLIT_WORDS = [
39
- "but", "however", "nevertheless", "yet", "still",
40
- "therefore", "thus", "hence", "consequently",
41
- "moreover", "furthermore", "additionally",
42
- "meanwhile", "alternatively", "otherwise",
43
- "namely", "specifically", "for example", "such as",
44
- "in fact", "indeed", "notably",
45
- "in contrast", "on the other hand", "conversely",
46
- "in conclusion", "to summarize", "finally"
47
- ]
48
-
49
- device = (
50
- "cuda"
51
- if torch.cuda.is_available()
52
- else "mps" if torch.backends.mps.is_available() else "cpu"
53
- )
54
-
55
- print(f"Using {device} device")
56
-
57
- pipe = pipeline(
58
- "automatic-speech-recognition",
59
- model="openai/whisper-large-v3-turbo",
60
- torch_dtype=torch.float16,
61
- device=device,
62
- )
63
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
64
-
65
- # --------------------- Settings -------------------- #
66
-
67
- target_sample_rate = 24000
68
- n_mel_channels = 100
69
- hop_length = 256
70
- target_rms = 0.1
71
- nfe_step = 32 # 16, 32
72
- cfg_strength = 2.0
73
- ode_method = "euler"
74
- sway_sampling_coef = -1.0
75
- speed = 1.0
76
- # fix_duration = 27 # None or float (duration in seconds)
77
- fix_duration = None
78
-
79
-
80
- def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
81
- ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
82
- # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
83
- vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
84
- model = CFM(
85
- transformer=model_cls(
86
- **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
87
- ),
88
- mel_spec_kwargs=dict(
89
- target_sample_rate=target_sample_rate,
90
- n_mel_channels=n_mel_channels,
91
- hop_length=hop_length,
92
- ),
93
- odeint_kwargs=dict(
94
- method=ode_method,
95
- ),
96
- vocab_char_map=vocab_char_map,
97
- ).to(device)
98
-
99
- model = load_checkpoint(model, ckpt_path, device, use_ema = True)
100
-
101
- return model
102
-
103
-
104
- # load models
105
- F5TTS_model_cfg = dict(
106
- dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
107
- )
108
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
109
-
110
- F5TTS_ema_model = load_model(
111
- "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
112
- )
113
- E2TTS_ema_model = load_model(
114
- "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
115
- )
116
-
117
- def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
118
- if len(text.encode('utf-8')) <= max_chars:
119
- return [text]
120
- if text[-1] not in ['。', '.', '!', '!', '?', '?']:
121
- text += '.'
122
-
123
- sentences = re.split('([。.!?!?])', text)
124
- sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
125
-
126
- batches = []
127
- current_batch = ""
128
-
129
- def split_by_words(text):
130
- words = text.split()
131
- current_word_part = ""
132
- word_batches = []
133
- for word in words:
134
- if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
135
- current_word_part += word + ' '
136
- else:
137
- if current_word_part:
138
- # Try to find a suitable split word
139
- for split_word in split_words:
140
- split_index = current_word_part.rfind(' ' + split_word + ' ')
141
- if split_index != -1:
142
- word_batches.append(current_word_part[:split_index].strip())
143
- current_word_part = current_word_part[split_index:].strip() + ' '
144
- break
145
- else:
146
- # If no suitable split word found, just append the current part
147
- word_batches.append(current_word_part.strip())
148
- current_word_part = ""
149
- current_word_part += word + ' '
150
- if current_word_part:
151
- word_batches.append(current_word_part.strip())
152
- return word_batches
153
-
154
- for sentence in sentences:
155
- if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
156
- current_batch += sentence
157
- else:
158
- # If adding this sentence would exceed the limit
159
- if current_batch:
160
- batches.append(current_batch)
161
- current_batch = ""
162
-
163
- # If the sentence itself is longer than max_chars, split it
164
- if len(sentence.encode('utf-8')) > max_chars:
165
- # First, try to split by colon
166
- colon_parts = sentence.split(':')
167
- if len(colon_parts) > 1:
168
- for part in colon_parts:
169
- if len(part.encode('utf-8')) <= max_chars:
170
- batches.append(part)
171
- else:
172
- # If colon part is still too long, split by comma
173
- comma_parts = re.split('[,,]', part)
174
- if len(comma_parts) > 1:
175
- current_comma_part = ""
176
- for comma_part in comma_parts:
177
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
178
- current_comma_part += comma_part + ','
179
- else:
180
- if current_comma_part:
181
- batches.append(current_comma_part.rstrip(','))
182
- current_comma_part = comma_part + ','
183
- if current_comma_part:
184
- batches.append(current_comma_part.rstrip(','))
185
- else:
186
- # If no comma, split by words
187
- batches.extend(split_by_words(part))
188
- else:
189
- # If no colon, split by comma
190
- comma_parts = re.split('[,,]', sentence)
191
- if len(comma_parts) > 1:
192
- current_comma_part = ""
193
- for comma_part in comma_parts:
194
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
195
- current_comma_part += comma_part + ','
196
- else:
197
- if current_comma_part:
198
- batches.append(current_comma_part.rstrip(','))
199
- current_comma_part = comma_part + ','
200
- if current_comma_part:
201
- batches.append(current_comma_part.rstrip(','))
202
- else:
203
- # If no comma, split by words
204
- batches.extend(split_by_words(sentence))
205
- else:
206
- current_batch = sentence
207
-
208
- if current_batch:
209
- batches.append(current_batch)
210
-
211
- return batches
212
-
213
- def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
214
- if exp_name == "F5-TTS":
215
- ema_model = F5TTS_ema_model
216
- elif exp_name == "E2-TTS":
217
- ema_model = E2TTS_ema_model
218
-
219
- audio, sr = ref_audio
220
- if audio.shape[0] > 1:
221
- audio = torch.mean(audio, dim=0, keepdim=True)
222
-
223
- rms = torch.sqrt(torch.mean(torch.square(audio)))
224
- if rms < target_rms:
225
- audio = audio * target_rms / rms
226
- if sr != target_sample_rate:
227
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
228
- audio = resampler(audio)
229
- audio = audio.to(device)
230
-
231
- generated_waves = []
232
- spectrograms = []
233
-
234
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
235
- # Prepare the text
236
- if len(ref_text[-1].encode('utf-8')) == 1:
237
- ref_text = ref_text + " "
238
- text_list = [ref_text + gen_text]
239
- final_text_list = convert_char_to_pinyin(text_list)
240
-
241
- # Calculate duration
242
- ref_audio_len = audio.shape[-1] // hop_length
243
- zh_pause_punc = r"。,、;:?!"
244
- ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
245
- gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
246
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
247
-
248
- # inference
249
- with torch.inference_mode():
250
- generated, _ = ema_model.sample(
251
- cond=audio,
252
- text=final_text_list,
253
- duration=duration,
254
- steps=nfe_step,
255
- cfg_strength=cfg_strength,
256
- sway_sampling_coef=sway_sampling_coef,
257
- )
258
-
259
- generated = generated[:, ref_audio_len:, :]
260
- generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
261
- generated_wave = vocos.decode(generated_mel_spec.cpu())
262
- if rms < target_rms:
263
- generated_wave = generated_wave * rms / target_rms
264
-
265
- # wav -> numpy
266
- generated_wave = generated_wave.squeeze().cpu().numpy()
267
-
268
- generated_waves.append(generated_wave)
269
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
270
-
271
- # Combine all generated waves
272
- final_wave = np.concatenate(generated_waves)
273
-
274
- # Remove silence
275
- if remove_silence:
276
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
277
- sf.write(f.name, final_wave, target_sample_rate)
278
- aseg = AudioSegment.from_file(f.name)
279
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
280
- non_silent_wave = AudioSegment.silent(duration=0)
281
- for non_silent_seg in non_silent_segs:
282
- non_silent_wave += non_silent_seg
283
- aseg = non_silent_wave
284
- aseg.export(f.name, format="wav")
285
- final_wave, _ = torchaudio.load(f.name)
286
- final_wave = final_wave.squeeze().cpu().numpy()
287
-
288
- # Create a combined spectrogram
289
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
290
-
291
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
292
- spectrogram_path = tmp_spectrogram.name
293
- save_spectrogram(combined_spectrogram, spectrogram_path)
294
-
295
- return (target_sample_rate, final_wave), spectrogram_path
296
-
297
- def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
298
- if not custom_split_words.strip():
299
- custom_words = [word.strip() for word in custom_split_words.split(',')]
300
- global SPLIT_WORDS
301
- SPLIT_WORDS = custom_words
302
-
303
- print(gen_text)
304
-
305
- gr.Info("Converting audio...")
306
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
307
- aseg = AudioSegment.from_file(ref_audio_orig)
308
-
309
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
310
- non_silent_wave = AudioSegment.silent(duration=0)
311
- for non_silent_seg in non_silent_segs:
312
- non_silent_wave += non_silent_seg
313
- aseg = non_silent_wave
314
-
315
- audio_duration = len(aseg)
316
- if audio_duration > 15000:
317
- gr.Warning("Audio is over 15s, clipping to only first 15s.")
318
- aseg = aseg[:15000]
319
- aseg.export(f.name, format="wav")
320
- ref_audio = f.name
321
-
322
- if not ref_text.strip():
323
- gr.Info("No reference text provided, transcribing reference audio...")
324
- ref_text = pipe(
325
- ref_audio,
326
- chunk_length_s=30,
327
- batch_size=128,
328
- generate_kwargs={"task": "transcribe"},
329
- return_timestamps=False,
330
- )["text"].strip()
331
- gr.Info("Finished transcription")
332
- else:
333
- gr.Info("Using custom reference text...")
334
-
335
- # Split the input text into batches
336
- audio, sr = torchaudio.load(ref_audio)
337
- max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
338
- gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
339
- print('ref_text', ref_text)
340
- for i, gen_text in enumerate(gen_text_batches):
341
- print(f'gen_text {i}', gen_text)
342
-
343
- gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
344
- return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
345
-
346
- def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
347
- # Split the script into speaker blocks
348
- speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
349
- speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
350
-
351
- generated_audio_segments = []
352
-
353
- for i in range(0, len(speaker_blocks), 2):
354
- speaker = speaker_blocks[i]
355
- text = speaker_blocks[i+1].strip()
356
-
357
- # Determine which speaker is talking
358
- if speaker == speaker1_name:
359
- ref_audio = ref_audio1
360
- ref_text = ref_text1
361
- elif speaker == speaker2_name:
362
- ref_audio = ref_audio2
363
- ref_text = ref_text2
364
- else:
365
- continue # Skip if the speaker is neither speaker1 nor speaker2
366
-
367
- # Generate audio for this block
368
- audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
369
-
370
- # Convert the generated audio to a numpy array
371
- sr, audio_data = audio
372
-
373
- # Save the audio data as a WAV file
374
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
375
- sf.write(temp_file.name, audio_data, sr)
376
- audio_segment = AudioSegment.from_wav(temp_file.name)
377
-
378
- generated_audio_segments.append(audio_segment)
379
-
380
- # Add a short pause between speakers
381
- pause = AudioSegment.silent(duration=500) # 500ms pause
382
- generated_audio_segments.append(pause)
383
-
384
- # Concatenate all audio segments
385
- final_podcast = sum(generated_audio_segments)
386
-
387
- # Export the final podcast
388
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
389
- podcast_path = temp_file.name
390
- final_podcast.export(podcast_path, format="wav")
391
-
392
- return podcast_path
393
-
394
- def parse_speechtypes_text(gen_text):
395
- # Pattern to find (Emotion)
396
- pattern = r'\((.*?)\)'
397
-
398
- # Split the text by the pattern
399
- tokens = re.split(pattern, gen_text)
400
-
401
- segments = []
402
-
403
- current_emotion = 'Regular'
404
-
405
- for i in range(len(tokens)):
406
- if i % 2 == 0:
407
- # This is text
408
- text = tokens[i].strip()
409
- if text:
410
- segments.append({'emotion': current_emotion, 'text': text})
411
- else:
412
- # This is emotion
413
- emotion = tokens[i].strip()
414
- current_emotion = emotion
415
-
416
- return segments
417
-
418
- def update_speed(new_speed):
419
- global speed
420
- speed = new_speed
421
- return f"Speed set to: {speed}"
422
-
423
- with gr.Blocks() as app_credits:
424
- gr.Markdown("""
425
- # Credits
426
-
427
- * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
428
- * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
429
- """)
430
- with gr.Blocks() as app_tts:
431
- gr.Markdown("# Batched TTS")
432
- ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
433
- gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
434
- model_choice = gr.Radio(
435
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
436
- )
437
- generate_btn = gr.Button("Synthesize", variant="primary")
438
- with gr.Accordion("Advanced Settings", open=False):
439
- ref_text_input = gr.Textbox(
440
- label="Reference Text",
441
- info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
442
- lines=2,
443
- )
444
- remove_silence = gr.Checkbox(
445
- label="Remove Silences",
446
- info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
447
- value=True,
448
- )
449
- split_words_input = gr.Textbox(
450
- label="Custom Split Words",
451
- info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
452
- lines=2,
453
- )
454
- speed_slider = gr.Slider(
455
- label="Speed",
456
- minimum=0.3,
457
- maximum=2.0,
458
- value=speed,
459
- step=0.1,
460
- info="Adjust the speed of the audio.",
461
- )
462
- speed_slider.change(update_speed, inputs=speed_slider)
463
-
464
- audio_output = gr.Audio(label="Synthesized Audio")
465
- spectrogram_output = gr.Image(label="Spectrogram")
466
-
467
- generate_btn.click(
468
- infer,
469
- inputs=[
470
- ref_audio_input,
471
- ref_text_input,
472
- gen_text_input,
473
- model_choice,
474
- remove_silence,
475
- split_words_input,
476
- ],
477
- outputs=[audio_output, spectrogram_output],
478
- )
479
-
480
- with gr.Blocks() as app_podcast:
481
- gr.Markdown("# Podcast Generation")
482
- speaker1_name = gr.Textbox(label="Speaker 1 Name")
483
- ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
484
- ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
485
-
486
- speaker2_name = gr.Textbox(label="Speaker 2 Name")
487
- ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
488
- ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
489
-
490
- script_input = gr.Textbox(label="Podcast Script", lines=10,
491
- placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
492
-
493
- podcast_model_choice = gr.Radio(
494
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
495
- )
496
- podcast_remove_silence = gr.Checkbox(
497
- label="Remove Silences",
498
- value=True,
499
- )
500
- generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
501
- podcast_output = gr.Audio(label="Generated Podcast")
502
-
503
- def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
504
- return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
505
-
506
- generate_podcast_btn.click(
507
- podcast_generation,
508
- inputs=[
509
- script_input,
510
- speaker1_name,
511
- ref_audio_input1,
512
- ref_text_input1,
513
- speaker2_name,
514
- ref_audio_input2,
515
- ref_text_input2,
516
- podcast_model_choice,
517
- podcast_remove_silence,
518
- ],
519
- outputs=podcast_output,
520
- )
521
-
522
- def parse_emotional_text(gen_text):
523
- # Pattern to find (Emotion)
524
- pattern = r'\((.*?)\)'
525
-
526
- # Split the text by the pattern
527
- tokens = re.split(pattern, gen_text)
528
-
529
- segments = []
530
-
531
- current_emotion = 'Regular'
532
-
533
- for i in range(len(tokens)):
534
- if i % 2 == 0:
535
- # This is text
536
- text = tokens[i].strip()
537
- if text:
538
- segments.append({'emotion': current_emotion, 'text': text})
539
- else:
540
- # This is emotion
541
- emotion = tokens[i].strip()
542
- current_emotion = emotion
543
-
544
- return segments
545
-
546
- with gr.Blocks() as app_emotional:
547
- # New section for emotional generation
548
- gr.Markdown(
549
- """
550
- # Multiple Speech-Type Generation
551
-
552
- This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
553
-
554
- **Example Input:**
555
-
556
- (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
557
- """
558
- )
559
-
560
- gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
561
-
562
- # Regular speech type (mandatory)
563
- with gr.Row():
564
- regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
565
- regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
566
- regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
567
-
568
- # Additional speech types (up to 9 more)
569
- max_speech_types = 10
570
- speech_type_names = []
571
- speech_type_audios = []
572
- speech_type_ref_texts = []
573
- speech_type_delete_btns = []
574
-
575
- for i in range(max_speech_types - 1):
576
- with gr.Row():
577
- name_input = gr.Textbox(label='Speech Type Name', visible=False)
578
- audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
579
- ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
580
- delete_btn = gr.Button("Delete", variant="secondary", visible=False)
581
- speech_type_names.append(name_input)
582
- speech_type_audios.append(audio_input)
583
- speech_type_ref_texts.append(ref_text_input)
584
- speech_type_delete_btns.append(delete_btn)
585
-
586
- # Button to add speech type
587
- add_speech_type_btn = gr.Button("Add Speech Type")
588
-
589
- # Keep track of current number of speech types
590
- speech_type_count = gr.State(value=0)
591
-
592
- # Function to add a speech type
593
- def add_speech_type_fn(speech_type_count):
594
- if speech_type_count < max_speech_types - 1:
595
- speech_type_count += 1
596
- # Prepare updates for the components
597
- name_updates = []
598
- audio_updates = []
599
- ref_text_updates = []
600
- delete_btn_updates = []
601
- for i in range(max_speech_types - 1):
602
- if i < speech_type_count:
603
- name_updates.append(gr.update(visible=True))
604
- audio_updates.append(gr.update(visible=True))
605
- ref_text_updates.append(gr.update(visible=True))
606
- delete_btn_updates.append(gr.update(visible=True))
607
- else:
608
- name_updates.append(gr.update())
609
- audio_updates.append(gr.update())
610
- ref_text_updates.append(gr.update())
611
- delete_btn_updates.append(gr.update())
612
- else:
613
- # Optionally, show a warning
614
- # gr.Warning("Maximum number of speech types reached.")
615
- name_updates = [gr.update() for _ in range(max_speech_types - 1)]
616
- audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
617
- ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
618
- delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
619
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
620
-
621
- add_speech_type_btn.click(
622
- add_speech_type_fn,
623
- inputs=speech_type_count,
624
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
625
- )
626
-
627
- # Function to delete a speech type
628
- def make_delete_speech_type_fn(index):
629
- def delete_speech_type_fn(speech_type_count):
630
- # Prepare updates
631
- name_updates = []
632
- audio_updates = []
633
- ref_text_updates = []
634
- delete_btn_updates = []
635
-
636
- for i in range(max_speech_types - 1):
637
- if i == index:
638
- name_updates.append(gr.update(visible=False, value=''))
639
- audio_updates.append(gr.update(visible=False, value=None))
640
- ref_text_updates.append(gr.update(visible=False, value=''))
641
- delete_btn_updates.append(gr.update(visible=False))
642
- else:
643
- name_updates.append(gr.update())
644
- audio_updates.append(gr.update())
645
- ref_text_updates.append(gr.update())
646
- delete_btn_updates.append(gr.update())
647
-
648
- speech_type_count = max(0, speech_type_count - 1)
649
-
650
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
651
-
652
- return delete_speech_type_fn
653
-
654
- for i, delete_btn in enumerate(speech_type_delete_btns):
655
- delete_fn = make_delete_speech_type_fn(i)
656
- delete_btn.click(
657
- delete_fn,
658
- inputs=speech_type_count,
659
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
660
- )
661
-
662
- # Text input for the prompt
663
- gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
664
-
665
- # Model choice
666
- model_choice_emotional = gr.Radio(
667
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
668
- )
669
-
670
- with gr.Accordion("Advanced Settings", open=False):
671
- remove_silence_emotional = gr.Checkbox(
672
- label="Remove Silences",
673
- value=True,
674
- )
675
-
676
- # Generate button
677
- generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
678
-
679
- # Output audio
680
- audio_output_emotional = gr.Audio(label="Synthesized Audio")
681
-
682
- def generate_emotional_speech(
683
- regular_audio,
684
- regular_ref_text,
685
- gen_text,
686
- *args,
687
- ):
688
- num_additional_speech_types = max_speech_types - 1
689
- speech_type_names_list = args[:num_additional_speech_types]
690
- speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
691
- speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
692
- model_choice = args[3 * num_additional_speech_types]
693
- remove_silence = args[3 * num_additional_speech_types + 1]
694
-
695
- # Collect the speech types and their audios into a dict
696
- speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
697
-
698
- for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
699
- if name_input and audio_input:
700
- speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
701
-
702
- # Parse the gen_text into segments
703
- segments = parse_speechtypes_text(gen_text)
704
-
705
- # For each segment, generate speech
706
- generated_audio_segments = []
707
- current_emotion = 'Regular'
708
-
709
- for segment in segments:
710
- emotion = segment['emotion']
711
- text = segment['text']
712
-
713
- if emotion in speech_types:
714
- current_emotion = emotion
715
- else:
716
- # If emotion not available, default to Regular
717
- current_emotion = 'Regular'
718
-
719
- ref_audio = speech_types[current_emotion]['audio']
720
- ref_text = speech_types[current_emotion].get('ref_text', '')
721
-
722
- # Generate speech for this segment
723
- audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
724
- sr, audio_data = audio
725
-
726
- generated_audio_segments.append(audio_data)
727
-
728
- # Concatenate all audio segments
729
- if generated_audio_segments:
730
- final_audio_data = np.concatenate(generated_audio_segments)
731
- return (sr, final_audio_data)
732
- else:
733
- gr.Warning("No audio generated.")
734
- return None
735
-
736
- generate_emotional_btn.click(
737
- generate_emotional_speech,
738
- inputs=[
739
- regular_audio,
740
- regular_ref_text,
741
- gen_text_input_emotional,
742
- ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
743
- model_choice_emotional,
744
- remove_silence_emotional,
745
- ],
746
- outputs=audio_output_emotional,
747
- )
748
-
749
- # Validation function to disable Generate button if speech types are missing
750
- def validate_speech_types(
751
- gen_text,
752
- regular_name,
753
- *args
754
- ):
755
- num_additional_speech_types = max_speech_types - 1
756
- speech_type_names_list = args[:num_additional_speech_types]
757
-
758
- # Collect the speech types names
759
- speech_types_available = set()
760
- if regular_name:
761
- speech_types_available.add(regular_name)
762
- for name_input in speech_type_names_list:
763
- if name_input:
764
- speech_types_available.add(name_input)
765
-
766
- # Parse the gen_text to get the speech types used
767
- segments = parse_emotional_text(gen_text)
768
- speech_types_in_text = set(segment['emotion'] for segment in segments)
769
-
770
- # Check if all speech types in text are available
771
- missing_speech_types = speech_types_in_text - speech_types_available
772
-
773
- if missing_speech_types:
774
- # Disable the generate button
775
- return gr.update(interactive=False)
776
- else:
777
- # Enable the generate button
778
- return gr.update(interactive=True)
779
-
780
- gen_text_input_emotional.change(
781
- validate_speech_types,
782
- inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
783
- outputs=generate_emotional_btn
784
- )
785
- with gr.Blocks() as app:
786
- gr.Markdown(
787
- """
788
- # E2/F5 TTS
789
-
790
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
791
-
792
- * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
793
- * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
794
-
795
- The checkpoints support English and Chinese.
796
-
797
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
798
-
799
- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
800
- """
801
- )
802
- gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
803
-
804
- @click.command()
805
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
806
- @click.option("--host", "-H", default=None, help="Host to run the app on")
807
- @click.option(
808
- "--share",
809
- "-s",
810
- default=False,
811
- is_flag=True,
812
- help="Share the app via Gradio share link",
813
- )
814
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
815
- def main(port, host, share, api):
816
- global app
817
- print(f"Starting app...")
818
- app.queue(api_open=api).launch(
819
- server_name=host, server_port=port, share=share, show_api=api
820
- )
821
-
822
-
823
- if __name__ == "__main__":
824
- main()