File size: 3,035 Bytes
49d537b
 
77e8f11
49d537b
 
 
9bab838
49d537b
 
 
 
 
77e8f11
 
 
49d537b
 
 
77e8f11
49d537b
77e8f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49d537b
 
 
9adfb4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77e8f11
49d537b
77e8f11
49d537b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77e8f11
 
 
 
 
 
 
 
 
 
 
49d537b
 
 
 
77e8f11
49d537b
 
77e8f11
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
from argparse import ArgumentParser
from loguru import logger
from tools.llama.generate import launch_thread_safe_queue
from tools.vqgan.inference import load_model as load_decoder_model
from tools.webui import build_app

def parse_args():
    """Parse command-line options for the web UI launcher.

    Returns:
        argparse.Namespace with checkpoint paths, device/precision flags,
        and Gradio UI options (``max_gradio_length``, ``theme``).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4-sft-yth-lora",
        help="Path to the Llama checkpoint",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        help="Path to the VQ-GAN checkpoint",
    )
    parser.add_argument(
        "--decoder-config-name",
        type=str,
        default="firefly_gan_vq",
        help="VQ-GAN config name",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on (cpu or cuda)",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Use half precision",
    )
    # BUG FIX: the original declared --compile with action="store_true" AND
    # default=True, making the flag a no-op — compilation could never be
    # disabled.  Keep the default and --compile for backward compatibility,
    # and add --no-compile (same dest) so it can actually be turned off.
    parser.add_argument(
        "--compile",
        action="store_true",
        default=True,
        help="Compile the model for optimized inference",
    )
    parser.add_argument(
        "--no-compile",
        dest="compile",
        action="store_false",
        help="Disable model compilation",
    )
    parser.add_argument(
        "--max-gradio-length",
        type=int,
        default=0,
        help="Maximum length for Gradio input",
    )
    parser.add_argument(
        "--theme",
        type=str,
        default="light",
        help="Theme for the Gradio app",
    )
    return parser.parse_args()

def inference(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
):
    """Placeholder inference routine.

    Logs the requested text and returns a fixed confirmation string.
    The generation parameters are accepted for interface compatibility
    but are currently ignored by this stub.
    """
    logger.info(f"Running inference on: {text}")
    # Simulated inference: echo the input back through a fixed template.
    return f"Processed text: {text}"
    
def main():
    """Load the Llama and VQ-GAN models, warm up, and launch the web UI."""
    args = parse_args()

    # --half selects float16; otherwise default to bfloat16.
    if args.half:
        args.precision = torch.half
    else:
        args.precision = torch.bfloat16

    logger.info("Loading Llama model...")
    llama_request_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    vqgan_decoder = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Dry run so the first real request does not pay the warm-up cost.
    inference(
        text="Hello, world!",
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
        max_new_tokens=0,
        chunk_length=100,
        top_p=0.7,
        repetition_penalty=1.2,
        temperature=0.7,
    )

    logger.info("Warming up done, launching the web UI...")

    # NOTE(review): --theme and --max-gradio-length are parsed but never
    # passed to build_app()/launch() — confirm whether build_app reads them
    # elsewhere or whether they should be forwarded here.
    app = build_app()
    app.launch(show_api=True)


# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()