uto1125 committed
Commit 74272e7 · verified · 1 Parent(s): f9a6c91

Update app.py

Files changed (1)
app.py +24 -516
app.py CHANGED
@@ -1,520 +1,28 @@
-import gc
-import html
-import io
-import os
-import queue
-import wave
-from argparse import ArgumentParser
-from functools import partial
-from pathlib import Path
-
-import gradio as gr
-import librosa
-import numpy as np
-import pyrootutils
-import torch
-from loguru import logger
-from transformers import AutoTokenizer
-
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-
-from fish_speech.i18n import i18n
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
-from tools.api import decode_vq_tokens, encode_reference
-from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
-from tools.llama.generate import (
-    GenerateRequest,
-    GenerateResponse,
-    WrappedGenerateResponse,
-    launch_thread_safe_queue,
-)
-from tools.vqgan.inference import load_model as load_decoder_model
-
-# Make einx happy
-os.environ["EINX_FILTER_TRACEBACK"] = "false"
-
-HEADER_MD = f"""# Fish Speech
-
-{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
-
-{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
-
-{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
-
-{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
-"""
-
-TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
-SPACE_IMPORTED = False
-
-# Define the parameter variables
-llama_checkpoint_path = "checkpoints/fish-speech-1.4-sft-yth-lora"
-decoder_checkpoint_path = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
-decoder_config_name = "firefly_gan_vq"
-device = "cpu"
-
-
-def build_html_error_message(error):
-    return f"""
-    <div style="color: red;
-    font-weight: bold;">
-        {html.escape(str(error))}
-    </div>
-    """
-
-
-@torch.inference_mode()
-def inference(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    streaming=False,
-):
-    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        return (
-            None,
-            None,
-            i18n("Text is too long, please keep it under {} characters.").format(
-                args.max_gradio_length
-            ),
-        )
-
-    # Parse reference audio aka prompt
-    prompt_tokens = encode_reference(
-        decoder_model=decoder_model,
-        reference_audio=reference_audio,
-        enable_reference_audio=enable_reference_audio,
-    )
-
-    # LLAMA Inference
-    request = dict(
-        device=device,  # use the specified device
-        max_new_tokens=max_new_tokens,
-        text=text,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        temperature=temperature,
-        compile=args.compile,
-        iterative_prompt=chunk_length > 0,
-        chunk_length=chunk_length,
-        max_length=2048,
-        prompt_tokens=prompt_tokens if enable_reference_audio else None,
-        prompt_text=reference_text if enable_reference_audio else None,
-    )
-
-    response_queue = queue.Queue()
-    llama_queue.put(
-        GenerateRequest(
-            request=request,
-            response_queue=response_queue,
-        )
-    )
-
-    if streaming:
-        yield wav_chunk_header(), None, None
-
-    segments = []
-
-    while True:
-        result: WrappedGenerateResponse = response_queue.get()
-        if result.status == "error":
-            yield None, None, build_html_error_message(result.response)
-            break
-
-        result: GenerateResponse = result.response
-        if result.action == "next":
-            break
-
-        with autocast_exclude_mps(device_type=device, dtype=args.precision):  # use the specified device
-            fake_audios = decode_vq_tokens(
-                decoder_model=decoder_model,
-                codes=result.codes,
-            )
-
-        fake_audios = fake_audios.float().cpu().numpy()
-        segments.append(fake_audios)
-
-        if streaming:
-            yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
-
-    if len(segments) == 0:
-        return (
-            None,
-            None,
-            build_html_error_message(
-                i18n("No audio generated, please check the input text.")
-            ),
-        )
-
-    # No matter streaming or not, we need to return the final audio
-    audio = np.concatenate(segments, axis=0)
-    yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
-    if torch.cuda.is_available():  # skip this when no GPU is available
-        torch.cuda.empty_cache()
-        gc.collect()
-
-
-def inference_with_auto_rerank(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    use_auto_rerank,
-    streaming=False,
-):
-
-    max_attempts = 2 if use_auto_rerank else 1
-    best_wer = float("inf")
-    best_audio = None
-    best_sample_rate = None
-
-    for attempt in range(max_attempts):
-        audio_generator = inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            streaming=False,
-        )
-
-        # Collect the audio data
-        for _ in audio_generator:
-            pass
-        _, (sample_rate, audio), message = _
-
-        if audio is None:
-            return None, None, message
-
-        if not use_auto_rerank:
-            return None, (sample_rate, audio), None
-
-        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
-        wer = calculate_wer(text, asr_result["text"])
-        if wer <= 0.3 and not asr_result["huge_gap"]:
-            return None, (sample_rate, audio), None
-
-        if wer < best_wer:
-            best_wer = wer
-            best_audio = audio
-            best_sample_rate = sample_rate
-
-        if attempt == max_attempts - 1:
-            break
-
-    return None, (best_sample_rate, best_audio), None
-
-
-inference_stream = partial(inference, streaming=True)
-
-n_audios = 4
-
-global_audio_list = []
-global_error_list = []
-
-
-def inference_wrapper(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    batch_infer_num,
-    if_load_asr_model,
-):
-    audios = []
-    errors = []
-
-    for _ in range(batch_infer_num):
-        result = inference_with_auto_rerank(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            if_load_asr_model,
-        )
-
-        _, audio_data, error_message = result
-
-        audios.append(
-            gr.Audio(value=audio_data if audio_data else None, visible=True),
-        )
-        errors.append(
-            gr.HTML(value=error_message if error_message else None, visible=True),
-        )
-
-    for _ in range(batch_infer_num, n_audios):
-        audios.append(
-            gr.Audio(value=None, visible=False),
-        )
-        errors.append(
-            gr.HTML(value=None, visible=False),
-        )
-
-    return None, *audios, *errors
-
-
-def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
-    buffer = io.BytesIO()
-
-    with wave.open(buffer, "wb") as wav_file:
-        wav_file.setnchannels(channels)
-        wav_file.setsampwidth(bit_depth // 8)
-        wav_file.setframerate(sample_rate)
-
-    wav_header_bytes = buffer.getvalue()
-    buffer.close()
-    return wav_header_bytes
-
-
-def normalize_text(user_input, use_normalization):
-    if use_normalization:
-        return ChnNormedText(raw_text=user_input).normalize()
-    else:
-        return user_input
-
-
-asr_model = None
-
-
-def change_if_load_asr_model(if_load):
-    global asr_model
-
-    if if_load:
-        gr.Warning("Loading faster whisper model...")
-        if asr_model is None:
-            asr_model = load_model()
-        return gr.Checkbox(label="Unload faster whisper model", value=if_load)
-
-    if if_load is False:
-        gr.Warning("Unloading faster whisper model...")
-        del asr_model
-        asr_model = None
-        if torch.cuda.is_available():  # skip this when no GPU is available
-            torch.cuda.empty_cache()
-            gc.collect()
-        return gr.Checkbox(label="Load faster whisper model", value=if_load)
-
-
-def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
-    if if_load and asr_model is not None:
-        if (
-            if_auto_label
-            and enable_ref
-            and ref_audio
-            and ref_text.strip() == ""
-        ):
-            ref_text = batch_asr(asr_model, [ref_audio])[0]["text"]
-            return ref_text
-    return ref_text
-
-
-def setup_gradio_interface():
-    with gr.Blocks() as app:
-        gr.Markdown(HEADER_MD)
-
-        with gr.Row():
-            with gr.Column(scale=2):
-                text_box = gr.Textbox(
-                    label=i18n("Input Text"),
-                    placeholder=TEXTBOX_PLACEHOLDER,
-                    max_lines=6,
-                )
-                normalization_checkbox = gr.Checkbox(
-                    label=i18n("Enable Text Normalization"),
-                    value=False,
-                )
-
-                reference_audio_file = gr.Audio(
-                    label=i18n("Reference Audio"),
-                    type="filepath",
-                    source="upload",
-                    interactive=True,
-                )
-
-                reference_text_box = gr.Textbox(
-                    label=i18n("Reference Text"),
-                    placeholder=i18n("Put your reference text here."),
-                    max_lines=3,
-                )
-
-                with gr.Row():
-                    max_new_tokens_input = gr.Slider(
-                        label=i18n("Max New Tokens"),
-                        minimum=1,
-                        maximum=200,
-                        value=60,
-                        step=1,
-                    )
-                    chunk_length_input = gr.Slider(
-                        label=i18n("Chunk Length"),
-                        minimum=0,
-                        maximum=20,
-                        value=0,
-                        step=1,
-                    )
-
-                with gr.Row():
-                    temperature_input = gr.Slider(
-                        label=i18n("Temperature"),
-                        minimum=0,
-                        maximum=1,
-                        value=0.7,
-                        step=0.01,
-                    )
-                    repetition_penalty_input = gr.Slider(
-                        label=i18n("Repetition Penalty"),
-                        minimum=0,
-                        maximum=2,
-                        value=1,
-                        step=0.01,
-                    )
-                    top_p_input = gr.Slider(
-                        label=i18n("Top P"),
-                        minimum=0,
-                        maximum=1,
-                        value=0.9,
-                        step=0.01,
-                    )
-
-                with gr.Row():
-                    load_asr_model_checkbox = gr.Checkbox(
-                        label=i18n("Load ASR Model"),
-                        value=False,
-                    )
-                    auto_label_checkbox = gr.Checkbox(
-                        label=i18n("Auto Labeling"),
-                        value=False,
-                    )
-
-            with gr.Column(scale=1):
-                submit_btn = gr.Button(i18n("Submit"))
-
-                output_audio = gr.Audio(label=i18n("Generated Audio"))
-                output_error = gr.HTML(label=i18n("Error Message"))
-
-        submit_btn.click(
-            inference_wrapper,
-            inputs=[
-                text_box,
-                load_asr_model_checkbox,
-                reference_audio_file,
-                reference_text_box,
-                max_new_tokens_input,
-                chunk_length_input,
-                top_p_input,
-                repetition_penalty_input,
-                temperature_input,
-                gr.Slider(value=n_audios, visible=False),
-            ],
-            outputs=[output_error, output_audio],
-        )
-
-        # Interface to reload ASR model
-        load_asr_model_checkbox.change(
-            change_if_load_asr_model,
-            inputs=[load_asr_model_checkbox],
-            outputs=[load_asr_model_checkbox],
-        )
-
-        # Interface for auto labeling
-        auto_label_checkbox.change(
-            change_if_auto_label,
-            inputs=[
-                auto_label_checkbox,
-                load_asr_model_checkbox,
-                reference_audio_file,
-                reference_text_box,
-            ],
-            outputs=[reference_text_box],
-        )
-
-    app.launch()
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--max-gradio-length",
-        type=int,
-        default=2048,
-        help="Maximum length of input text for Gradio.",
-    )
-    parser.add_argument(
-        "--compile",
-        action="store_true",
-        help="Compile the model.",
-    )
-    parser.add_argument(
-        "--precision",
-        type=str,
-        default="float32",
-        help="Model precision, one of ['float16', 'float32', 'bfloat16'].",
-    )
-    parser.add_argument(
-        "--llama-checkpoint-path",
-        type=str,
-        required=True,
-        help="Path to the Llama checkpoint.",
-    )
-    parser.add_argument(
-        "--decoder-checkpoint-path",
-        type=str,
-        required=True,
-        help="Path to the decoder checkpoint.",
-    )
-    parser.add_argument(
-        "--decoder-config-name",
-        type=str,
-        required=True,
-        help="Name of the decoder config.",
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cpu",
-        help="Device to run the model on, one of ['cpu', 'cuda'].",
-    )
+import argparse
+import subprocess
+
+def main():
+    # Create the command-line argument parser
+    parser = argparse.ArgumentParser(description="Launch the Fish Speech app")
+    parser.add_argument("--llama-checkpoint-path", type=str, required=True, help="Path to the Llama checkpoint")
+    parser.add_argument("--decoder-checkpoint-path", type=str, required=True, help="Path to the decoder checkpoint")
+    parser.add_argument("--decoder-config-name", type=str, required=True, help="Name of the decoder config")
+    parser.add_argument("--device", type=str, default="cpu", help="Device type (cpu or cuda)")
+
+    # Parse the arguments
     args = parser.parse_args()

-    llama_checkpoint_path = args.llama_checkpoint_path
-    decoder_checkpoint_path = args.decoder_checkpoint_path
-    decoder_config_name = args.decoder_config_name
-    device = args.device
-
-    logger.info("Loading decoder model...")
-    decoder_model = load_decoder_model()
-
-    # Initialize Llama and ASR models
-    llama_queue = launch_thread_safe_queue()
-    logger.info("Loading Llama model...")
-    load_model(0)
-
-    # Setup the Gradio interface
-    setup_gradio_interface()
+    # Build the command
+    command = [
+        "python", "tools/webui.py",
+        "--llama-checkpoint-path", args.llama_checkpoint_path,
+        "--decoder-checkpoint-path", args.decoder_checkpoint_path,
+        "--decoder-config-name", args.decoder_config_name,
+        "--device", args.device
+    ]
+
+    # Run the command
+    subprocess.run(command)
+
+if __name__ == "__main__":
+    main()