uto1125 committed (verified)
Commit 77e8f11 · 1 Parent(s): 77b2a9a

Update app.py

Files changed (1):
  app.py +59 -608
app.py CHANGED
@@ -1,618 +1,66 @@
 import os
-import queue
-from huggingface_hub import snapshot_download
-import hydra
-import numpy as np
-import wave
-import io
-import pyrootutils
-import gc
-
-# Download if not exists
-os.makedirs("checkpoints", exist_ok=True)
-#snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")
-
-print("All checkpoints downloaded")
-
-import html
-import os
-import threading
-from argparse import ArgumentParser
-from pathlib import Path
-from functools import partial
-
-import gradio as gr
-import librosa
 import torch
-import torchaudio
-# torch.cuda.is_available = lambda: False
-# torchaudio.set_audio_backend("soundfile")
-
+from argparse import ArgumentParser
 from loguru import logger
-from transformers import AutoTokenizer
-
 from tools.llama.generate import launch_thread_safe_queue
-from tools.vqgan.inference import load_model as load_vqgan_model
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from tools.api import decode_vq_tokens, encode_reference
-from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
-from tools.llama.generate import (
-    GenerateRequest,
-    GenerateResponse,
-    WrappedGenerateResponse,
-    launch_thread_safe_queue,
-)
 from tools.vqgan.inference import load_model as load_decoder_model
 
-# Make einx happy
-os.environ["EINX_FILTER_TRACEBACK"] = "false"
-
-
-HEADER_MD = """# Fish Speech
-
-## The demo in this space is version 1.4, Please check [Fish Audio](https://fish.audio) for the best model.
-## 该 Demo 为 Fish Speech 1.4 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
-
-A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
-由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
-
-You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).
-你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.4) 找到模型.
-
-Related code and weights are released under CC BY-NC-SA 4.0 License.
-相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
-
-We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
-我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
-
-The model running in this WebUI is Fish Speech V1.4 Medium.
-在此 WebUI 中运行的模型是 Fish Speech V1.4 Medium.
-"""
-
-TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
-
-try:
-    import spaces
-
-    GPU_DECORATOR = spaces.GPU
-except ImportError:
-
-    def GPU_DECORATOR(func):
-        def wrapper(*args, **kwargs):
-            return func(*args, **kwargs)
-
-        return wrapper
-
-
-def build_html_error_message(error):
-    return f"""
-    <div style="color: red;
-    font-weight: bold;">
-        {html.escape(error)}
-    </div>
-    """
-
-
-@GPU_DECORATOR
-@torch.inference_mode()
-def inference(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    streaming=False
-):
-    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        return (
-            None,
-            None,
-            "Text is too long, please keep it under {} characters.".format(
-                args.max_gradio_length
-            ),
-        )
-
-    # Parse reference audio aka prompt
-    prompt_tokens = encode_reference(
-        decoder_model=decoder_model,
-        reference_audio=reference_audio,
-        enable_reference_audio=enable_reference_audio,
-    )
-
-    # LLAMA Inference
-    request = dict(
-        device=decoder_model.device,
-        max_new_tokens=max_new_tokens,
-        text=text,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        temperature=temperature,
-        compile=args.compile,
-        iterative_prompt=chunk_length > 0,
-        chunk_length=chunk_length,
-        max_length=2048,
-        prompt_tokens=prompt_tokens if enable_reference_audio else None,
-        prompt_text=reference_text if enable_reference_audio else None,
-    )
-
-    response_queue = queue.Queue()
-    llama_queue.put(
-        GenerateRequest(
-            request=request,
-            response_queue=response_queue,
-        )
-    )
-
-    segments = []
-
-    while True:
-        result: WrappedGenerateResponse = response_queue.get()
-        if result.status == "error":
-            return None, None, build_html_error_message(result.response)
-
-        result: GenerateResponse = result.response
-        if result.action == "next":
-            break
-
-        with torch.autocast(
-            device_type=(
-                "cpu"
-                if decoder_model.device.type == "mps"
-                else decoder_model.device.type
-            ),
-            dtype=args.precision,
-        ):
-            fake_audios = decode_vq_tokens(
-                decoder_model=decoder_model,
-                codes=result.codes,
-            )
-
-        fake_audios = fake_audios.float().cpu().numpy()
-        segments.append(fake_audios)
-
-    if len(segments) == 0:
-        return (
-            None,
-            None,
-            build_html_error_message(
-                "No audio generated, please check the input text."
-            ),
-        )
-
-    # Return the final audio
-    audio = np.concatenate(segments, axis=0)
-    return None, (decoder_model.spec_transform.sample_rate, audio), None
-
-    if torch.cpu.is_available():
-        torch.cpu.empty_cache()
-        gc.collect()
-
-
-def inference_with_auto_rerank(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    use_auto_rerank,
-    streaming=False,
-):
-    max_attempts = 2 if use_auto_rerank else 1
-    best_wer = float("inf")
-    best_audio = None
-    best_sample_rate = None
-
-    for attempt in range(max_attempts):
-        _, (sample_rate, audio), message = inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            streaming=False,
-        )
-
-        if audio is None:
-            return None, None, message
-
-        if not use_auto_rerank:
-            return None, (sample_rate, audio), None
-
-        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
-        wer = calculate_wer(text, asr_result["text"])
-
-        if wer <= 0.3 and not asr_result["huge_gap"]:
-            return None, (sample_rate, audio), None
-
-        if wer < best_wer:
-            best_wer = wer
-            best_audio = audio
-            best_sample_rate = sample_rate
-
-        if attempt == max_attempts - 1:
-            break
-
-    return None, (best_sample_rate, best_audio), None
-
-
-n_audios = 4
-
-global_audio_list = []
-global_error_list = []
-
-def inference_wrapper(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    batch_infer_num,
-    if_load_asr_model,
-):
-    audios = []
-    errors = []
-
-    for _ in range(batch_infer_num):
-        result = inference_with_auto_rerank(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            if_load_asr_model,
-        )
-
-        _, audio_data, error_message = result
-
-        audios.append(
-            gr.Audio(value=audio_data if audio_data else None, visible=True),
-        )
-        errors.append(
-            gr.HTML(value=error_message if error_message else None, visible=True),
-        )
-
-    for _ in range(batch_infer_num, n_audios):
-        audios.append(
-            gr.Audio(value=None, visible=False),
-        )
-        errors.append(
-            gr.HTML(value=None, visible=False),
-        )
-
-    return None, *audios, *errors
-
-
-def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
-    buffer = io.BytesIO()
-
-    with wave.open(buffer, "wb") as wav_file:
-        wav_file.setnchannels(channels)
-        wav_file.setsampwidth(bit_depth // 8)
-        wav_file.setframerate(sample_rate)
-
-    wav_header_bytes = buffer.getvalue()
-    buffer.close()
-    return wav_header_bytes
-
-
-def normalize_text(user_input, use_normalization):
-    if use_normalization:
-        return ChnNormedText(raw_text=user_input).normalize()
-    else:
-        return user_input
-
-
-asr_model = None
-
-
-def change_if_load_asr_model(if_load):
-    global asr_model
-
-    if if_load:
-        gr.Warning("Loading faster whisper model...")
-        if asr_model is None:
-            asr_model = load_model()
-        return gr.Checkbox(label="Unload faster whisper model", value=if_load)
-
-    if if_load is False:
-        gr.Warning("Unloading faster whisper model...")
-        del asr_model
-        asr_model = None
-        if torch.cpu.is_available():
-            torch.cpu.empty_cache()
-        gc.collect()
-        return gr.Checkbox(label="Load faster whisper model", value=if_load)
-
-
-def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
-    if if_load and asr_model is not None:
-        if (
-            if_auto_label
-            and enable_ref
-            and ref_audio is not None
-            and ref_text.strip() == ""
-        ):
-            data, sample_rate = librosa.load(ref_audio)
-            res = batch_asr(asr_model, [data], sample_rate)[0]
-            ref_text = res["text"]
-    else:
-        gr.Warning("Whisper model not loaded!")
-
-    return gr.Textbox(value=ref_text)
-
-
-def build_app():
-    with gr.Blocks(theme=gr.themes.Base()) as app:
-        gr.Markdown(HEADER_MD)
-
-        # Use light theme by default
-        app.load(
-            None,
-            None,
-            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
-            % args.theme,
-        )
-
-        # Inference
-        with gr.Row():
-            with gr.Column(scale=3):
-                text = gr.Textbox(
-                    label="Input Text", placeholder=TEXTBOX_PLACEHOLDER, lines=10
-                )
-                refined_text = gr.Textbox(
-                    label="Realtime Transform Text",
-                    placeholder=
-                    "Normalization Result Preview (Currently Only Chinese)",
-                    lines=5,
-                    interactive=False,
-                )
-
-                with gr.Row():
-                    if_refine_text = gr.Checkbox(
-                        label="Text Normalization (ZH)",
-                        value=False,
-                        scale=1,
-                    )
-
-                    if_load_asr_model = gr.Checkbox(
-                        label="Load / Unload ASR model for auto-reranking",
-                        value=False,
-                        scale=3,
-                    )
-
-                with gr.Row():
-                    with gr.Tab(label="Advanced Config"):
-                        chunk_length = gr.Slider(
-                            label="Iterative Prompt Length, 0 means off",
-                            minimum=0,
-                            maximum=500,
-                            value=200,
-                            step=8,
-                        )
-
-                        max_new_tokens = gr.Slider(
-                            label="Maximum tokens per batch, 0 means no limit",
-                            minimum=0,
-                            maximum=2048,
-                            value=1024, # 0 means no limit
-                            step=8,
-                        )
-
-                        top_p = gr.Slider(
-                            label="Top-P",
-                            minimum=0.6,
-                            maximum=0.9,
-                            value=0.7,
-                            step=0.01,
-                        )
-
-                        repetition_penalty = gr.Slider(
-                            label="Repetition Penalty",
-                            minimum=1,
-                            maximum=1.5,
-                            value=1.2,
-                            step=0.01,
-                        )
-
-                        temperature = gr.Slider(
-                            label="Temperature",
-                            minimum=0.6,
-                            maximum=0.9,
-                            value=0.7,
-                            step=0.01,
-                        )
-
-                    with gr.Tab(label="Reference Audio"):
-                        gr.Markdown(
-                            "5 to 10 seconds of reference audio, useful for specifying speaker."
-                        )
-
-                        enable_reference_audio = gr.Checkbox(
-                            label="Enable Reference Audio",
-                        )
-
-                        # Add dropdown for selecting example audio files
-                        example_audio_files = [f for f in os.listdir("examples") if f.endswith(".wav")]
-                        example_audio_dropdown = gr.Dropdown(
-                            label="Select Example Audio",
-                            choices=[""] + example_audio_files,
-                            value=""
-                        )
-
-                        reference_audio = gr.Audio(
-                            label="Reference Audio",
-                            type="filepath",
-                        )
-                        with gr.Row():
-                            if_auto_label = gr.Checkbox(
-                                label="Auto Labeling",
-                                min_width=100,
-                                scale=0,
-                                value=False,
-                            )
-                            reference_text = gr.Textbox(
-                                label="Reference Text",
-                                lines=1,
-                                placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                                value="",
-                            )
-                    with gr.Tab(label="Batch Inference"):
-                        batch_infer_num = gr.Slider(
-                            label="Batch infer nums",
-                            minimum=1,
-                            maximum=n_audios,
-                            step=1,
-                            value=1,
-                        )
-
-            with gr.Column(scale=3):
-                for _ in range(n_audios):
-                    with gr.Row():
-                        error = gr.HTML(
-                            label="Error Message",
-                            visible=True if _ == 0 else False,
-                        )
-                        global_error_list.append(error)
-                    with gr.Row():
-                        audio = gr.Audio(
-                            label="Generated Audio",
-                            type="numpy",
-                            interactive=False,
-                            visible=True if _ == 0 else False,
-                        )
-                        global_audio_list.append(audio)
-
-                with gr.Row():
-                    stream_audio = gr.Audio(
-                        label="Streaming Audio",
-                        streaming=True,
-                        autoplay=True,
-                        interactive=False,
-                        show_download_button=True,
-                    )
-        with gr.Row():
-            with gr.Column(scale=3):
-                generate = gr.Button(
-                    value="\U0001F3A7 " + "Generate", variant="primary"
-                )
-                generate_stream = gr.Button(
-                    value="\U0001F3A7 " + "Streaming Generate",
-                    variant="primary",
-                )
-
-        text.input(
-            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
-        )
-
-        if_load_asr_model.change(
-            fn=change_if_load_asr_model,
-            inputs=[if_load_asr_model],
-            outputs=[if_load_asr_model],
-        )
-
-        if_auto_label.change(
-            fn=lambda: gr.Textbox(value=""),
-            inputs=[],
-            outputs=[reference_text],
-        ).then(
-            fn=change_if_auto_label,
-            inputs=[
-                if_load_asr_model,
-                if_auto_label,
-                enable_reference_audio,
-                reference_audio,
-                reference_text,
-            ],
-            outputs=[reference_text],
-        )
-
-        def select_example_audio(audio_file):
-            if audio_file:
-                audio_path = os.path.join("examples", audio_file)
-                lab_file = os.path.splitext(audio_file)[0] + ".lab"
-                lab_path = os.path.join("examples", lab_file)
-
-                if os.path.exists(lab_path):
-                    with open(lab_path, "r", encoding="utf-8") as f:
-                        lab_content = f.read().strip()
-                else:
-                    lab_content = ""
-
-                return audio_path, lab_content, True
-            return None, "", False
-
-        # Connect the dropdown to update reference audio and text
-        example_audio_dropdown.change(
-            fn=select_example_audio,
-            inputs=[example_audio_dropdown],
-            outputs=[reference_audio, reference_text, enable_reference_audio]
-        )
-        # # Submit
-        generate.click(
-            inference_wrapper,
-            [
-                refined_text,
-                enable_reference_audio,
-                reference_audio,
-                reference_text,
-                max_new_tokens,
-                chunk_length,
-                top_p,
-                repetition_penalty,
-                temperature,
-                batch_infer_num,
-                if_load_asr_model,
-            ],
-            [stream_audio, *global_audio_list, *global_error_list],
-            concurrency_limit=1,
-        )
-    return app
-
 
 def parse_args():
     parser = ArgumentParser()
     parser.add_argument(
         "--llama-checkpoint-path",
-        type=Path,
-        default="checkpoints/fish-speech-1.4",
+        type=str,
+        default="checkpoints/fish-speech-1.4-sft-yth-lora",
+        help="Path to the Llama checkpoint"
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
-        type=Path,
+        type=str,
         default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+        help="Path to the VQ-GAN checkpoint"
+    )
+    parser.add_argument(
+        "--decoder-config-name",
+        type=str,
+        default="firefly_gan_vq",
+        help="VQ-GAN config name"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="Device to run on (cpu or cuda)"
+    )
+    parser.add_argument(
+        "--half",
+        action="store_true",
+        help="Use half precision"
+    )
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        default=True,
+        help="Compile the model for optimized inference"
+    )
+    parser.add_argument(
+        "--max-gradio-length",
+        type=int,
+        default=0,
+        help="Maximum length for Gradio input"
+    )
+    parser.add_argument(
+        "--theme",
+        type=str,
+        default="light",
+        help="Theme for the Gradio app"
     )
-    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-    parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--half", action="store_true")
-    parser.add_argument("--compile", action="store_true",default=True)
-    parser.add_argument("--max-gradio-length", type=int, default=0)
-    parser.add_argument("--theme", type=str, default="light")
-
     return parser.parse_args()
 
 
-if __name__ == "__main__":
+def main():
     args = parse_args()
+
     args.precision = torch.half if args.half else torch.bfloat16
 
     logger.info("Loading Llama model...")
@@ -632,22 +80,25 @@ if __name__ == "__main__":
 
     logger.info("Decoder model loaded, warming up...")
 
-    # Dry run to check if the model is loaded correctly and avoid the first-time latency
-    list(
-        inference(
-            text="Hello, world!",
-            enable_reference_audio=False,
-            reference_audio=None,
-            reference_text="",
-            max_new_tokens=0,
-            chunk_length=100,
-            top_p=0.7,
-            repetition_penalty=1.2,
-            temperature=0.7,
-        )
+    # Perform a dry run to warm up the model
+    inference(
+        text="Hello, world!",
+        enable_reference_audio=False,
+        reference_audio=None,
+        reference_text="",
+        max_new_tokens=0,
+        chunk_length=100,
+        top_p=0.7,
+        repetition_penalty=1.2,
+        temperature=0.7,
    )
 
     logger.info("Warming up done, launching the web UI...")
 
+    # Launch the Gradio app
     app = build_app()
     app.launch(show_api=True)
+
+
+if __name__ == "__main__":
+    main()
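
Note on the new parse_args above: --compile is declared with action="store_true" together with default=True, so the option is always true; passing or omitting the flag makes no difference, and compilation cannot be disabled from the command line. A minimal standalone sketch (illustrative only, not part of this commit) demonstrating that behavior:

    from argparse import ArgumentParser

    # Mirrors the --compile declaration in the committed parse_args.
    parser = ArgumentParser()
    parser.add_argument("--compile", action="store_true", default=True)

    # store_true can only ever set the value to True, and the default is
    # already True, so both invocations print True.
    print(parser.parse_args([]).compile)             # True
    print(parser.parse_args(["--compile"]).compile)  # True

If the flag were meant to be toggled, argparse.BooleanOptionalAction (Python 3.9+) would also generate a --no-compile form.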