File size: 2,680 Bytes
49d537b
 
77e8f11
49d537b
 
 
 
 
 
 
 
 
77e8f11
 
 
49d537b
 
 
77e8f11
49d537b
77e8f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49d537b
 
 
 
77e8f11
49d537b
77e8f11
49d537b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77e8f11
 
 
 
 
 
 
 
 
 
 
49d537b
 
 
 
77e8f11
49d537b
 
77e8f11
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import torch
from argparse import ArgumentParser
from loguru import logger
from tools.llama.generate import launch_thread_safe_queue
from tools.vqgan.inference import load_model as load_decoder_model


def parse_args(argv=None):
    """Parse the command-line options for the web UI launcher.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches the
            previous zero-argument behaviour.

    Returns:
        The ``argparse.Namespace`` holding ``llama_checkpoint_path``,
        ``decoder_checkpoint_path``, ``decoder_config_name``, ``device``,
        ``half``, ``compile``, ``max_gradio_length`` and ``theme``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4-sft-yth-lora",
        help="Path to the Llama checkpoint",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        help="Path to the VQ-GAN checkpoint",
    )
    parser.add_argument(
        "--decoder-config-name",
        type=str,
        default="firefly_gan_vq",
        help="VQ-GAN config name",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on (cpu or cuda)",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Use half precision",
    )
    # `--compile` defaults to True, so on its own the flag could never be
    # switched off. Keep it for backward compatibility and pair it with
    # `--no-compile`, which writes False into the same destination.
    parser.add_argument(
        "--compile",
        dest="compile",
        action="store_true",
        default=True,
        help="Compile the model for optimized inference (enabled by default)",
    )
    parser.add_argument(
        "--no-compile",
        dest="compile",
        action="store_false",
        help="Disable model compilation",
    )
    parser.add_argument(
        "--max-gradio-length",
        type=int,
        default=0,
        help="Maximum length for Gradio input",
    )
    parser.add_argument(
        "--theme",
        type=str,
        default="light",
        help="Theme for the Gradio app",
    )
    return parser.parse_args(argv)


def main():
    """Load both models, issue one warm-up request, then launch the web UI.

    Steps, in order: parse the CLI options, derive ``args.precision`` from
    ``--half``, start the Llama worker via ``launch_thread_safe_queue``, load
    the decoder via ``load_decoder_model``, call ``inference`` once with a
    fixed request, and finally build and launch the app.
    """
    args = parse_args()

    # Half precision only when requested; bfloat16 otherwise.
    if args.half:
        args.precision = torch.half
    else:
        args.precision = torch.bfloat16

    logger.info("Loading Llama model...")
    llama_options = {
        "checkpoint_path": args.llama_checkpoint_path,
        "device": args.device,
        "precision": args.precision,
        "compile": args.compile,
    }
    llama_queue = launch_thread_safe_queue(**llama_options)
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_options = {
        "config_name": args.decoder_config_name,
        "checkpoint_path": args.decoder_checkpoint_path,
        "device": args.device,
    }
    decoder_model = load_decoder_model(**decoder_options)

    logger.info("Decoder model loaded, warming up...")

    # Dry run so the first real request does not pay the start-up cost.
    # NOTE(review): `inference` and `build_app` are neither defined nor
    # imported in the visible part of this file — confirm where they come
    # from, and how they are meant to reach `llama_queue`, `decoder_model`
    # and `args`, which are local to this function.
    warmup_request = {
        "text": "Hello, world!",
        "enable_reference_audio": False,
        "reference_audio": None,
        "reference_text": "",
        "max_new_tokens": 0,
        "chunk_length": 100,
        "top_p": 0.7,
        "repetition_penalty": 1.2,
        "temperature": 0.7,
    }
    inference(**warmup_request)

    logger.info("Warming up done, launching the web UI...")

    # Build the Gradio app and serve it.
    build_app().launch(show_api=True)


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()