# "Spaces: Sleeping" — HuggingFace Spaces page-status text captured during
# extraction; not part of the program. Kept here as a comment so the file
# remains valid Python.
import os
from argparse import ArgumentParser, BooleanOptionalAction

import torch
from loguru import logger

from tools.llama.generate import launch_thread_safe_queue
from tools.vqgan.inference import load_model as load_decoder_model
def parse_args():
    """Parse command-line arguments for the web UI launcher.

    Returns:
        argparse.Namespace with checkpoint paths, device selection,
        precision/compile switches, and Gradio options.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4-sft-yth-lora",
        help="Path to the Llama checkpoint",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        help="Path to the VQ-GAN checkpoint",
    )
    parser.add_argument(
        "--decoder-config-name",
        type=str,
        default="firefly_gan_vq",
        help="VQ-GAN config name",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on (cpu or cuda)",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Use half precision",
    )
    # BUGFIX: the original combined action="store_true" with default=True,
    # which made the flag a no-op — compilation could never be disabled from
    # the command line. BooleanOptionalAction keeps `--compile` working and
    # the default of True, while adding `--no-compile` to turn it off.
    parser.add_argument(
        "--compile",
        action=BooleanOptionalAction,
        default=True,
        help="Compile the model for optimized inference",
    )
    parser.add_argument(
        "--max-gradio-length",
        type=int,
        default=0,
        help="Maximum length for Gradio input",
    )
    parser.add_argument(
        "--theme",
        type=str,
        default="light",
        help="Theme for the Gradio app",
    )
    return parser.parse_args()
def main():
    """Load the TTS models, warm them up, and launch the Gradio web UI."""
    args = parse_args()
    # torch.half only when explicitly requested via --half; otherwise
    # bfloat16 (the value parse_args leaves in args.half is a plain bool).
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )

    logger.info("Llama model loaded, loading VQ-GAN model...")
    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")
    # Dry run so compilation / kernel warm-up happens before the first real
    # request reaches the UI.
    # NOTE(review): `inference` and `build_app` are neither imported nor
    # defined in this file as shown, and `llama_queue` / `decoder_model` are
    # never referenced again here — presumably a webui helper module supplies
    # these names and picks up the models via module globals. Confirm before
    # shipping.
    inference(
        text="Hello, world!",
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
        max_new_tokens=0,
        chunk_length=100,
        top_p=0.7,
        repetition_penalty=1.2,
        temperature=0.7,
    )

    logger.info("Warming up done, launching the web UI...")
    # Launch the Gradio app with the REST API enabled.
    app = build_app()
    app.launch(show_api=True)


if __name__ == "__main__":
    main()