uto1125 committed on
Commit 49d537b · verified · 1 Parent(s): a75fb04

Update app.py

Files changed (1)
  1. app.py +653 -653
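
The two sides of the diff below are identical except for the new line 16, which monkey-patches `torch.cuda.is_available` so that later CUDA checks in the app report no GPU and the Space stays on its CPU code paths (note that, as committed, the patch line sits above the `import torch` on line 27). A minimal standalone sketch of the effect, for illustration only and not part of the commit:

import torch

# Override the CUDA probe before any model-loading code runs, so every
# later "if torch.cuda.is_available():" branch takes the CPU path.
torch.cuda.is_available = lambda: False

print(torch.cuda.is_available())  # False, even on a machine with a GPU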
app.py CHANGED
@@ -1,653 +1,653 @@
import os
import queue
from huggingface_hub import snapshot_download
import hydra
import numpy as np
import wave
import io
import pyrootutils
import gc

# Download if not exists
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")

print("All checkpoints downloaded")
-
+ torch.cuda.is_available = lambda: False
import html
import os
import threading
from argparse import ArgumentParser
from pathlib import Path
from functools import partial

import gradio as gr
import librosa
import torch
import torchaudio

torchaudio.set_audio_backend("soundfile")

from loguru import logger
from transformers import AutoTokenizer

from tools.llama.generate import launch_thread_safe_queue
from tools.vqgan.inference import load_model as load_vqgan_model
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from tools.api import decode_vq_tokens, encode_reference
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)
from tools.vqgan.inference import load_model as load_decoder_model

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"


HEADER_MD = """# Fish Speech

## The demo in this space is version 1.4, Please check [Fish Audio](https://fish.audio) for the best model.
## 该 Demo 为 Fish Speech 1.4 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.

A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.

You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).
你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.4) 找到模型.

Related code and weights are released under CC BY-NC-SA 4.0 License.
相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.

We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.

The model running in this WebUI is Fish Speech V1.4 Medium.
在此 WebUI 中运行的模型是 Fish Speech V1.4 Medium.
"""

TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""

try:
    import spaces

    GPU_DECORATOR = spaces.GPU
except ImportError:

    def GPU_DECORATOR(func):
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper


def build_html_error_message(error):
    return f"""
    <div style="color: red;
    font-weight: bold;">
        {html.escape(error)}
    </div>
    """


@GPU_DECORATOR
@torch.inference_mode()
def inference(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    streaming=False
):
    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
        return (
            None,
            None,
            "Text is too long, please keep it under {} characters.".format(
                args.max_gradio_length
            ),
        )

    # Parse reference audio aka prompt
    prompt_tokens = encode_reference(
        decoder_model=decoder_model,
        reference_audio=reference_audio,
        enable_reference_audio=enable_reference_audio,
    )

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=max_new_tokens,
        text=text,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=args.compile,
        iterative_prompt=chunk_length > 0,
        chunk_length=chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens if enable_reference_audio else None,
        prompt_text=reference_text if enable_reference_audio else None,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            return None, None, build_html_error_message(result.response)

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with torch.autocast(
            device_type=(
                "cpu"
                if decoder_model.device.type == "mps"
                else decoder_model.device.type
            ),
            dtype=args.precision,
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()
        segments.append(fake_audios)

    if len(segments) == 0:
        return (
            None,
            None,
            build_html_error_message(
                "No audio generated, please check the input text."
            ),
        )

    # Return the final audio
    audio = np.concatenate(segments, axis=0)
    return None, (decoder_model.spec_transform.sample_rate, audio), None

    if torch.cpu.is_available():
        torch.cpu.empty_cache()
        gc.collect()


def inference_with_auto_rerank(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    use_auto_rerank,
    streaming=False,
):
    max_attempts = 2 if use_auto_rerank else 1
    best_wer = float("inf")
    best_audio = None
    best_sample_rate = None

    for attempt in range(max_attempts):
        _, (sample_rate, audio), message = inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            streaming=False,
        )

        if audio is None:
            return None, None, message

        if not use_auto_rerank:
            return None, (sample_rate, audio), None

        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
        wer = calculate_wer(text, asr_result["text"])

        if wer <= 0.3 and not asr_result["huge_gap"]:
            return None, (sample_rate, audio), None

        if wer < best_wer:
            best_wer = wer
            best_audio = audio
            best_sample_rate = sample_rate

        if attempt == max_attempts - 1:
            break

    return None, (best_sample_rate, best_audio), None


n_audios = 4

global_audio_list = []
global_error_list = []

def inference_wrapper(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    batch_infer_num,
    if_load_asr_model,
):
    audios = []
    errors = []

    for _ in range(batch_infer_num):
        result = inference_with_auto_rerank(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            if_load_asr_model,
        )

        _, audio_data, error_message = result

        audios.append(
            gr.Audio(value=audio_data if audio_data else None, visible=True),
        )
        errors.append(
            gr.HTML(value=error_message if error_message else None, visible=True),
        )

    for _ in range(batch_infer_num, n_audios):
        audios.append(
            gr.Audio(value=None, visible=False),
        )
        errors.append(
            gr.HTML(value=None, visible=False),
        )

    return None, *audios, *errors


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


def normalize_text(user_input, use_normalization):
    if use_normalization:
        return ChnNormedText(raw_text=user_input).normalize()
    else:
        return user_input


asr_model = None


def change_if_load_asr_model(if_load):
    global asr_model

    if if_load:
        gr.Warning("Loading faster whisper model...")
        if asr_model is None:
            asr_model = load_model()
        return gr.Checkbox(label="Unload faster whisper model", value=if_load)

    if if_load is False:
        gr.Warning("Unloading faster whisper model...")
        del asr_model
        asr_model = None
        if torch.cpu.is_available():
            torch.cpu.empty_cache()
        gc.collect()
        return gr.Checkbox(label="Load faster whisper model", value=if_load)


def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
    if if_load and asr_model is not None:
        if (
            if_auto_label
            and enable_ref
            and ref_audio is not None
            and ref_text.strip() == ""
        ):
            data, sample_rate = librosa.load(ref_audio)
            res = batch_asr(asr_model, [data], sample_rate)[0]
            ref_text = res["text"]
    else:
        gr.Warning("Whisper model not loaded!")

    return gr.Textbox(value=ref_text)


def build_app():
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % args.theme,
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label="Input Text", placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )
                refined_text = gr.Textbox(
                    label="Realtime Transform Text",
                    placeholder=
                    "Normalization Result Preview (Currently Only Chinese)",
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    if_refine_text = gr.Checkbox(
                        label="Text Normalization (ZH)",
                        value=False,
                        scale=1,
                    )

                    if_load_asr_model = gr.Checkbox(
                        label="Load / Unload ASR model for auto-reranking",
                        value=False,
                        scale=3,
                    )

                with gr.Row():
                    with gr.Tab(label="Advanced Config"):
                        chunk_length = gr.Slider(
                            label="Iterative Prompt Length, 0 means off",
                            minimum=0,
                            maximum=500,
                            value=200,
                            step=8,
                        )

                        max_new_tokens = gr.Slider(
                            label="Maximum tokens per batch, 0 means no limit",
                            minimum=0,
                            maximum=2048,
                            value=1024, # 0 means no limit
                            step=8,
                        )

                        top_p = gr.Slider(
                            label="Top-P",
                            minimum=0.6,
                            maximum=0.9,
                            value=0.7,
                            step=0.01,
                        )

                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            minimum=1,
                            maximum=1.5,
                            value=1.2,
                            step=0.01,
                        )

                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.6,
                            maximum=0.9,
                            value=0.7,
                            step=0.01,
                        )

                    with gr.Tab(label="Reference Audio"):
                        gr.Markdown(
                            "5 to 10 seconds of reference audio, useful for specifying speaker."
                        )

                        enable_reference_audio = gr.Checkbox(
                            label="Enable Reference Audio",
                        )

                        # Add dropdown for selecting example audio files
                        example_audio_files = [f for f in os.listdir("examples") if f.endswith(".wav")]
                        example_audio_dropdown = gr.Dropdown(
                            label="Select Example Audio",
                            choices=[""] + example_audio_files,
                            value=""
                        )

                        reference_audio = gr.Audio(
                            label="Reference Audio",
                            type="filepath",
                        )
                        with gr.Row():
                            if_auto_label = gr.Checkbox(
                                label="Auto Labeling",
                                min_width=100,
                                scale=0,
                                value=False,
                            )
                            reference_text = gr.Textbox(
                                label="Reference Text",
                                lines=1,
                                placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                value="",
                            )
                    with gr.Tab(label="Batch Inference"):
                        batch_infer_num = gr.Slider(
                            label="Batch infer nums",
                            minimum=1,
                            maximum=n_audios,
                            step=1,
                            value=1,
                        )

            with gr.Column(scale=3):
                for _ in range(n_audios):
                    with gr.Row():
                        error = gr.HTML(
                            label="Error Message",
                            visible=True if _ == 0 else False,
                        )
                        global_error_list.append(error)
                    with gr.Row():
                        audio = gr.Audio(
                            label="Generated Audio",
                            type="numpy",
                            interactive=False,
                            visible=True if _ == 0 else False,
                        )
                        global_audio_list.append(audio)

                with gr.Row():
                    stream_audio = gr.Audio(
                        label="Streaming Audio",
                        streaming=True,
                        autoplay=True,
                        interactive=False,
                        show_download_button=True,
                    )
                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + "Generate", variant="primary"
                        )
                        generate_stream = gr.Button(
                            value="\U0001F3A7 " + "Streaming Generate",
                            variant="primary",
                        )

        text.input(
            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
        )

        if_load_asr_model.change(
            fn=change_if_load_asr_model,
            inputs=[if_load_asr_model],
            outputs=[if_load_asr_model],
        )

        if_auto_label.change(
            fn=lambda: gr.Textbox(value=""),
            inputs=[],
            outputs=[reference_text],
        ).then(
            fn=change_if_auto_label,
            inputs=[
                if_load_asr_model,
                if_auto_label,
                enable_reference_audio,
                reference_audio,
                reference_text,
            ],
            outputs=[reference_text],
        )

        def select_example_audio(audio_file):
            if audio_file:
                audio_path = os.path.join("examples", audio_file)
                lab_file = os.path.splitext(audio_file)[0] + ".lab"
                lab_path = os.path.join("examples", lab_file)

                if os.path.exists(lab_path):
                    with open(lab_path, "r", encoding="utf-8") as f:
                        lab_content = f.read().strip()
                else:
                    lab_content = ""

                return audio_path, lab_content, True
            return None, "", False

        # Connect the dropdown to update reference audio and text
        example_audio_dropdown.change(
            fn=select_example_audio,
            inputs=[example_audio_dropdown],
            outputs=[reference_audio, reference_text, enable_reference_audio]
        )
        # # Submit
        generate.click(
            inference_wrapper,
            [
                refined_text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                batch_infer_num,
                if_load_asr_model,
            ],
            [stream_audio, *global_audio_list, *global_error_list],
            concurrency_limit=1,
        )
    return app


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.4",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true",default=True)
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            text="Hello, world!",
            enable_reference_audio=False,
            reference_audio=None,
            reference_text="",
            max_new_tokens=0,
            chunk_length=100,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
        )
    )

    logger.info("Warming up done, launching the web UI...")

    app = build_app()
    app.launch(show_api=True)