"""Command-line launcher: load the Llama and VQ-GAN decoder models, then start the web UI."""

import os
import torch
from argparse import ArgumentParser
from loguru import logger

from tools.llama.generate import launch_thread_safe_queue
from tools.vqgan.inference import load_model as load_decoder_model


def parse_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4-sft-yth-lora",
        help="Path to the Llama checkpoint",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        help="Path to the VQ-GAN checkpoint",
    )
    parser.add_argument(
        "--decoder-config-name",
        type=str,
        default="firefly_gan_vq",
        help="VQ-GAN config name",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on (cpu or cuda)",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Use half precision",
    )
    # Compilation is on by default. With only a store_true flag and
    # default=True the value could never become False, so a paired
    # --no-compile flag writing to the same destination is provided.
    parser.add_argument(
        "--compile",
        action="store_true",
        default=True,
        help="Compile the model for optimized inference",
    )
    parser.add_argument(
        "--no-compile",
        dest="compile",
        action="store_false",
        help="Disable model compilation",
    )
    # NOTE(review): --max-gradio-length and --theme are parsed but not read
    # anywhere in the visible code; confirm whether the UI builder should
    # receive them.
    parser.add_argument(
        "--max-gradio-length",
        type=int,
        default=0,
        help="Maximum length for Gradio input",
    )
    parser.add_argument(
        "--theme",
        type=str,
        default="light",
        help="Theme for the Gradio app",
    )
    return parser.parse_args(argv)


def main():
    """Load both models, run one warm-up inference call, and launch the app."""
    args = parse_args()
    # --half selects float16; otherwise bfloat16 is used.
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )
    logger.info("Decoder model loaded, warming up...")

    # Perform a dry run to warm up the model.
    # NOTE(review): `inference` and `build_app` are not defined or imported in
    # the visible part of this file; they must be provided elsewhere or this
    # function raises NameError. `llama_queue` and `decoder_model` are locals
    # here, so confirm how those callables obtain the loaded models.
    # NOTE(review): the return value is discarded. If `inference` is a
    # generator, this call does no work until it is iterated -- confirm.
    inference(
        text="Hello, world!",
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
        max_new_tokens=0,
        chunk_length=100,
        top_p=0.7,
        repetition_penalty=1.2,
        temperature=0.7,
    )

    logger.info("Warming up done, launching the web UI...")

    # Launch the Gradio app
    app = build_app()
    app.launch(show_api=True)


if __name__ == "__main__":
    main()