uto1125 committed
Commit a3901af · verified · 1 Parent(s): 5383c92

Update app.py

Files changed (1)
  1. app.py +453 -99
app.py CHANGED
@@ -1,63 +1,230 @@
  import os
- import torch
  from argparse import ArgumentParser
  from loguru import logger
- from tools.llama.generate import launch_thread_safe_queue
  from tools.vqgan.inference import load_model as load_decoder_model
- import gradio as gr  # import Gradio

- def parse_args():
-     parser = ArgumentParser()
-     parser.add_argument(
-         "--llama-checkpoint-path",
-         type=str,
-         default="checkpoints/fish-speech-1.4-sft-yth-lora",
-         help="Path to the Llama checkpoint"
-     )
-     parser.add_argument(
-         "--decoder-checkpoint-path",
-         type=str,
-         default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-         help="Path to the VQ-GAN checkpoint"
-     )
-     parser.add_argument(
-         "--decoder-config-name",
-         type=str,
-         default="firefly_gan_vq",
-         help="VQ-GAN config name"
-     )
-     parser.add_argument(
-         "--device",
-         type=str,
-         default="cpu",
-         help="Device to run on (cpu or cuda)"
-     )
-     parser.add_argument(
-         "--half",
-         action="store_true",
-         help="Use half precision"
-     )
-     parser.add_argument(
-         "--compile",
-         action="store_true",
-         default=True,
-         help="Compile the model for optimized inference"
      )
-     parser.add_argument(
-         "--max-gradio-length",
-         type=int,
-         default=0,
-         help="Maximum length for Gradio input"
      )
-     parser.add_argument(
-         "--theme",
-         type=str,
-         default="light",
-         help="Theme for the Gradio app"
      )
-     return parser.parse_args()

- def inference(
      text,
      enable_reference_audio,
      reference_audio,

@@ -67,66 +234,253 @@ def inference(
      top_p,
      repetition_penalty,
      temperature,
  ):
-     logger.info(f"Running inference on: {text}")
-     # Simulate the inference process
-     result = f"Processed text: {text}"
-     return result

- def inference_function(text):
-     return f"Processed: {text}"

- def build_app(args):
      with gr.Blocks() as app:
-         gr.Markdown(f"# Fish Speech Inference - Theme: {args.theme}")
-         text_input = gr.Textbox(label="Input Text")
-         output = gr.Textbox(label="Output Text")
-         submit_button = gr.Button("Submit")

-         submit_button.click(fn=inference_function, inputs=text_input, outputs=output)
-     return app

- def main():
-     args = parse_args()

-     args.precision = torch.half if args.half else torch.bfloat16

-     logger.info("Loading Llama model...")
-     llama_queue = launch_thread_safe_queue(
-         checkpoint_path=args.llama_checkpoint_path,
-         device=args.device,
-         precision=args.precision,
-         compile=args.compile,
-     )
-     logger.info("Llama model loaded, loading VQ-GAN model...")

-     decoder_model = load_decoder_model(
-         config_name=args.decoder_config_name,
-         checkpoint_path=args.decoder_checkpoint_path,
-         device=args.device,
-     )

-     logger.info("Decoder model loaded, warming up...")
-
-     # Perform a dry run to warm up the model
-     inference(
-         text="Hello, world!",
-         enable_reference_audio=False,
-         reference_audio=None,
-         reference_text="",
-         max_new_tokens=0,
-         chunk_length=100,
-         top_p=0.7,
-         repetition_penalty=1.2,
-         temperature=0.7,
-     )

-     logger.info("Warming up done, launching the web UI...")

-     # Launch the Gradio app, passing args to build_app
-     app = build_app(args)
-     app.launch(show_api=True)


  if __name__ == "__main__":
-     main()

+ import gc
+ import html
+ import io
  import os
+ import queue
+ import wave
  from argparse import ArgumentParser
+ from functools import partial
+ from pathlib import Path
+
+ import gradio as gr
+ import librosa
+ import numpy as np
+ import pyrootutils
+ import torch
  from loguru import logger
+ from transformers import AutoTokenizer
+
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+ from fish_speech.i18n import i18n
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+ from fish_speech.utils import autocast_exclude_mps
+ from tools.api import decode_vq_tokens, encode_reference
+ from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+ from tools.llama.generate import (
+     GenerateRequest,
+     GenerateResponse,
+     WrappedGenerateResponse,
+     launch_thread_safe_queue,
+ )
  from tools.vqgan.inference import load_model as load_decoder_model

+ # Make einx happy
+ os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+ HEADER_MD = f"""# Fish Speech
+
+ {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
+
+ {i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
+
+ {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
+
+ {i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
+ """
+
+ TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
+ SPACE_IMPORTED = False
+
+
+ def build_html_error_message(error):
+     return f"""
+     <div style="color: red;
+     font-weight: bold;">
+         {html.escape(str(error))}
+     </div>
+     """
+
+
+ @torch.inference_mode()
+ def inference(
+     text,
+     enable_reference_audio,
+     reference_audio,
+     reference_text,
+     max_new_tokens,
+     chunk_length,
+     top_p,
+     repetition_penalty,
+     temperature,
+     streaming=False,
+ ):
+     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+         return (
+             None,
+             None,
+             i18n("Text is too long, please keep it under {} characters.").format(
+                 args.max_gradio_length
+             ),
+         )
+
+     # Parse reference audio aka prompt
+     prompt_tokens = encode_reference(
+         decoder_model=decoder_model,
+         reference_audio=reference_audio,
+         enable_reference_audio=enable_reference_audio,
      )
+
+     # LLAMA Inference
+     request = dict(
+         device="cpu",  # set to CPU
+         max_new_tokens=max_new_tokens,
+         text=text,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         temperature=temperature,
+         compile=args.compile,
+         iterative_prompt=chunk_length > 0,
+         chunk_length=chunk_length,
+         max_length=2048,
+         prompt_tokens=prompt_tokens if enable_reference_audio else None,
+         prompt_text=reference_text if enable_reference_audio else None,
      )
+
+     response_queue = queue.Queue()
+     llama_queue.put(
+         GenerateRequest(
+             request=request,
+             response_queue=response_queue,
+         )
      )

+     if streaming:
+         yield wav_chunk_header(), None, None
+
+     segments = []
+
+     while True:
+         result: WrappedGenerateResponse = response_queue.get()
+         if result.status == "error":
+             yield None, None, build_html_error_message(result.response)
+             break
+
+         result: GenerateResponse = result.response
+         if result.action == "next":
+             break
+
+         with autocast_exclude_mps(device_type="cpu", dtype=args.precision):  # set to CPU
+             fake_audios = decode_vq_tokens(
+                 decoder_model=decoder_model,
+                 codes=result.codes,
+             )
+
+         fake_audios = fake_audios.float().cpu().numpy()
+         segments.append(fake_audios)
+
+         if streaming:
+             yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+     if len(segments) == 0:
+         return (
+             None,
+             None,
+             build_html_error_message(
+                 i18n("No audio generated, please check the input text.")
+             ),
+         )
+
+     # No matter streaming or not, we need to return the final audio
+     audio = np.concatenate(segments, axis=0)
+     yield None, (decoder_model.spec_transform.sample_rate, audio), None
+
+     if torch.cuda.is_available():  # skipped when no GPU is available
+         torch.cuda.empty_cache()
+         gc.collect()
+
+
+ def inference_with_auto_rerank(
+     text,
+     enable_reference_audio,
+     reference_audio,
+     reference_text,
+     max_new_tokens,
+     chunk_length,
+     top_p,
+     repetition_penalty,
+     temperature,
+     use_auto_rerank,
+     streaming=False,
+ ):
+
+     max_attempts = 2 if use_auto_rerank else 1
+     best_wer = float("inf")
+     best_audio = None
+     best_sample_rate = None
+
+     for attempt in range(max_attempts):
+         audio_generator = inference(
+             text,
+             enable_reference_audio,
+             reference_audio,
+             reference_text,
+             max_new_tokens,
+             chunk_length,
+             top_p,
+             repetition_penalty,
+             temperature,
+             streaming=False,
+         )
+
+         # Collect the audio data
+         for _ in audio_generator:
+             pass
+         _, (sample_rate, audio), message = _
+
+         if audio is None:
+             return None, None, message
+
+         if not use_auto_rerank:
+             return None, (sample_rate, audio), None
+
+         asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
+         wer = calculate_wer(text, asr_result["text"])
+         if wer <= 0.3 and not asr_result["huge_gap"]:
+             return None, (sample_rate, audio), None
+
+         if wer < best_wer:
+             best_wer = wer
+             best_audio = audio
+             best_sample_rate = sample_rate
+
+         if attempt == max_attempts - 1:
+             break
+
+     return None, (best_sample_rate, best_audio), None
+
+
+ inference_stream = partial(inference, streaming=True)
+
+ n_audios = 4
+
+ global_audio_list = []
+ global_error_list = []
+
+
+ def inference_wrapper(
      text,
      enable_reference_audio,
      reference_audio,

      top_p,
      repetition_penalty,
      temperature,
+     batch_infer_num,
+     if_load_asr_model,
  ):
+     audios = []
+     errors = []
+
+     for _ in range(batch_infer_num):
+         result = inference_with_auto_rerank(
+             text,
+             enable_reference_audio,
+             reference_audio,
+             reference_text,
+             max_new_tokens,
+             chunk_length,
+             top_p,
+             repetition_penalty,
+             temperature,
+             if_load_asr_model,
+         )
+
+         _, audio_data, error_message = result
+
+         audios.append(
+             gr.Audio(value=audio_data if audio_data else None, visible=True),
+         )
+         errors.append(
+             gr.HTML(value=error_message if error_message else None, visible=True),
+         )
+
+     for _ in range(batch_infer_num, n_audios):
+         audios.append(
+             gr.Audio(value=None, visible=False),
+         )
+         errors.append(
+             gr.HTML(value=None, visible=False),
+         )
+
+     return None, *audios, *errors
+
+
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+     buffer = io.BytesIO()
+
+     with wave.open(buffer, "wb") as wav_file:
+         wav_file.setnchannels(channels)
+         wav_file.setsampwidth(bit_depth // 8)
+         wav_file.setframerate(sample_rate)
+
+     wav_header_bytes = buffer.getvalue()
+     buffer.close()
+     return wav_header_bytes
+
+
+ def normalize_text(user_input, use_normalization):
+     if use_normalization:
+         return ChnNormedText(raw_text=user_input).normalize()
+     else:
+         return user_input


+ asr_model = None
+
+
+ def change_if_load_asr_model(if_load):
+     global asr_model
+
+     if if_load:
+         gr.Warning("Loading faster whisper model...")
+         if asr_model is None:
+             asr_model = load_model()
+         return gr.Checkbox(label="Unload faster whisper model", value=if_load)
+
+     if if_load is False:
+         gr.Warning("Unloading faster whisper model...")
+         del asr_model
+         asr_model = None
+         if torch.cuda.is_available():  # skipped when no GPU is available
+             torch.cuda.empty_cache()
+             gc.collect()
+         return gr.Checkbox(label="Load faster whisper model", value=if_load)
+
+
+ def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
+     if if_load and asr_model is not None:
+         if (
+             if_auto_label
+             and enable_ref
+             and ref_audio is not None
+             and ref_text.strip() == ""
+         ):
+             data, sample_rate = librosa.load(ref_audio)
+             res = batch_asr(asr_model, [data], sample_rate)[0]
+             ref_text = res["text"]
+     return ref_text
+
+
+ def setup_gradio_interface():
      with gr.Blocks() as app:
+         gr.Markdown(HEADER_MD)

+         with gr.Row():
+             with gr.Column(scale=2):
+                 text_box = gr.Textbox(
+                     label=i18n("Input Text"),
+                     placeholder=TEXTBOX_PLACEHOLDER,
+                     max_lines=6,
+                 )
+                 normalization_checkbox = gr.Checkbox(
+                     label=i18n("Enable Text Normalization"),
+                     value=False,
+                 )

+                 reference_audio_file = gr.Audio(
+                     label=i18n("Reference Audio"),
+                     type="filepath",
+                     source="upload",
+                     interactive=True,
+                 )

+                 reference_text_box = gr.Textbox(
+                     label=i18n("Reference Text"),
+                     placeholder=i18n("Put your reference text here."),
+                     max_lines=3,
+                 )

+                 with gr.Row():
+                     max_new_tokens_input = gr.Slider(
+                         label=i18n("Max New Tokens"),
+                         minimum=1,
+                         maximum=200,
+                         value=60,
+                         step=1,
+                     )
+                     chunk_length_input = gr.Slider(
+                         label=i18n("Chunk Length"),
+                         minimum=0,
+                         maximum=20,
+                         value=0,
+                         step=1,
+                     )

+                 with gr.Row():
+                     temperature_input = gr.Slider(
+                         label=i18n("Temperature"),
+                         minimum=0,
+                         maximum=1,
+                         value=0.7,
+                         step=0.01,
+                     )
+                     repetition_penalty_input = gr.Slider(
+                         label=i18n("Repetition Penalty"),
+                         minimum=0,
+                         maximum=2,
+                         value=1,
+                         step=0.01,
+                     )
+                     top_p_input = gr.Slider(
+                         label=i18n("Top P"),
+                         minimum=0,
+                         maximum=1,
+                         value=0.9,
+                         step=0.01,
+                     )

+                 with gr.Row():
+                     load_asr_model_checkbox = gr.Checkbox(
+                         label=i18n("Load ASR Model"),
+                         value=False,
+                     )
+                     auto_label_checkbox = gr.Checkbox(
+                         label=i18n("Auto Labeling"),
+                         value=False,
+                     )
+
+             with gr.Column(scale=1):
+                 submit_btn = gr.Button(i18n("Submit"))
+
+                 output_audio = gr.Audio(label=i18n("Generated Audio"))
+                 output_error = gr.HTML(label=i18n("Error Message"))
+
+         submit_btn.click(
+             inference_wrapper,
+             inputs=[
+                 text_box,
+                 load_asr_model_checkbox,
+                 reference_audio_file,
+                 reference_text_box,
+                 max_new_tokens_input,
+                 chunk_length_input,
+                 top_p_input,
+                 repetition_penalty_input,
+                 temperature_input,
+                 gr.Slider(value=n_audios, visible=False),
+             ],
+             outputs=[output_error, output_audio],
+         )
+
+         # Interface to reload ASR model
+         load_asr_model_checkbox.change(
+             change_if_load_asr_model,
+             inputs=[load_asr_model_checkbox],
+             outputs=[load_asr_model_checkbox],
+         )

+         # Interface for auto labeling
+         auto_label_checkbox.change(
+             change_if_auto_label,
+             inputs=[
+                 auto_label_checkbox,
+                 load_asr_model_checkbox,
+                 reference_audio_file,
+                 reference_text_box,
+             ],
+             outputs=[reference_text_box],
+         )

+     app.launch()


  if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument(
+         "--max-gradio-length",
+         type=int,
+         default=2048,
+         help="Maximum length of input text for Gradio.",
+     )
+     parser.add_argument(
+         "--compile",
+         action="store_true",
+         help="Compile the model.",
+     )
+     parser.add_argument(
+         "--precision",
+         type=str,
+         default="float32",
+         help="Model precision, one of ['float16', 'float32', 'bfloat16'].",
+     )
+     args = parser.parse_args()
+
+     logger.info("Loading decoder model...")
+     decoder_model = load_decoder_model()
+
+     # Initialize Llama and ASR models
+     llama_queue = launch_thread_safe_queue()
+     logger.info("Loading Llama model...")
+     load_model(0)
+
+     # Setup the Gradio interface
+     setup_gradio_interface()
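
Note on the streaming path added in this commit: inference() first yields the bytes returned by wav_chunk_header() and then raw int16 PCM chunks, so a client can start playback before generation finishes. The following is a minimal, self-contained sketch of that same technique, not part of the commit; the pcm_chunk helper, the 440 Hz test tone, and the output filename are illustrative assumptions.

import io
import wave

import numpy as np


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    # Write an empty WAV container to an in-memory buffer; its bytes serve as a
    # stream header that audio clients can parse before any frames arrive.
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)
    return buffer.getvalue()


def pcm_chunk(audio_float):
    # Same scaling as in the diff: float samples in [-1, 1] become int16 bytes.
    return (np.asarray(audio_float) * 32768).astype(np.int16).tobytes()


if __name__ == "__main__":
    sample_rate = 44100
    # Illustrative payload: one second of a quiet 440 Hz sine wave.
    tone = 0.1 * np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate)
    stream = wav_chunk_header(sample_rate) + pcm_chunk(tone)
    # The header declares zero frames, so its size fields are not exact;
    # streaming consumers typically read until EOF rather than trusting them.
    with open("streamed_example.wav", "wb") as f:
        f.write(stream)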