"""Command-line entry point for Kolors text-to-image generation.

Each prompt is passed to ``text2img_gen`` (optionally with a reference image
for the IP-adapter pipeline) and the returned images are written as
``sample_<k>.png`` files under ``--output_root``.
"""

import argparse
import logging
import os

from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
    StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import (  # noqa
    StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from tqdm import tqdm

from asset3d_gen.models.text_model import (
    build_text2img_ip_pipeline,
    build_text2img_pipeline,
    text2img_gen,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory the default pipelines are built from when none is supplied.
KOLORS_WEIGHTS_DIR = "weights/Kolors"


def parse_args() -> argparse.Namespace:
    """Parse the text-to-image command-line options from ``sys.argv``.

    ``--prompts`` may also be a single path ending in ``.txt``; see
    ``_load_prompts``. ``--prompts`` and ``--output_root`` have no default
    and are validated by ``entrypoint`` after keyword overrides are applied.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Text to Image.")
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        help="List of prompts (space-separated).",
    )
    parser.add_argument(
        "--ref_image",
        type=str,
        nargs="+",
        help="List of ref_image paths (space-separated).",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=12.0,
        help="Guidance scale for the diffusion model.",
    )
    parser.add_argument(
        "--ref_scale",
        type=float,
        default=0.3,
        help="Reference image scale for the IP adapter.",
    )
    # Number of images requested from text2img_gen per prompt.
    parser.add_argument(
        "--n_sample",
        type=int,
        default=1,
    )
    # Output images are square: resolution x resolution.
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
    )
    # Forwarded to text2img_gen as `infer_step`.
    parser.add_argument(
        "--infer_step",
        type=int,
        default=50,
    )
    return parser.parse_args()


def _load_prompts(prompts: str | list[str]) -> list[str]:
    """Normalise the ``prompts`` option into a list of prompt strings.

    A bare string is treated as a one-element list. If the list holds exactly
    one entry ending in ``.txt``, that file is read as UTF-8 and every
    non-blank line (stripped) becomes a prompt; otherwise the entries are
    returned as given.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        # Explicit UTF-8 so non-ASCII prompts do not depend on the locale.
        with open(prompts[0], "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    return list(prompts)


def _expand_ref_images(
    ref_image: str | list[str], n_prompts: int
) -> list[str]:
    """Broadcast ``ref_image`` to one entry per prompt.

    A single path (bare string or one-element list) is repeated
    ``n_prompts`` times; a longer list is returned unchanged so the caller
    can check that its length matches the prompts.

    Raises:
        ValueError: If ``ref_image`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(ref_image, str):
        return [ref_image] * n_prompts
    if isinstance(ref_image, list):
        if len(ref_image) == 1:
            return ref_image * n_prompts
        return ref_image
    raise ValueError("Invalid ref_image paths.")


def entrypoint(
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP | None = None,
    **kwargs,
) -> list[str]:
    """Generate images for every prompt and save them under ``output_root``.

    Command-line options are parsed first; any keyword argument whose name
    matches an option and whose value is not ``None`` overrides it.

    Args:
        pipeline: A pre-built pipeline to reuse. When ``None``, one is built
            from ``KOLORS_WEIGHTS_DIR``: the IP-adapter variant if
            ``ref_scale`` is positive, otherwise the plain text-to-image one.
            ``ref_scale`` is forced to 0 when no reference image is given.
        **kwargs: Overrides for the parsed options (``prompts``,
            ``ref_image``, ``output_root``, ``guidance_scale``,
            ``ref_scale``, ``n_sample``, ``resolution``, ``infer_step``).

    Returns:
        Paths of every image written, across all prompts, in generation
        order. Empty if there were no prompts to generate.

    Raises:
        ValueError: If ``prompts`` or ``output_root`` is missing,
            ``ref_image`` has an unsupported type, or the number of
            reference images does not match the number of prompts.
    """
    args = parse_args()
    for key, value in kwargs.items():
        if hasattr(args, key) and value is not None:
            setattr(args, key, value)

    if args.prompts is None:
        raise ValueError("`prompts` is required (use --prompts).")
    if args.output_root is None:
        raise ValueError("`output_root` is required (use --output_root).")

    prompts = _load_prompts(args.prompts)
    os.makedirs(args.output_root, exist_ok=True)
    if not prompts:
        # Avoid building an expensive pipeline when there is nothing to do.
        logger.warning("No prompts to generate; nothing to do.")
        return []

    if not args.ref_image:
        # No reference image: disable the IP adapter so that, when no
        # pipeline is supplied, the plain text-to-image pipeline is built.
        args.ref_scale = 0
        ip_img_paths = [None] * len(prompts)
    else:
        ip_img_paths = _expand_ref_images(args.ref_image, len(prompts))

    # Explicit check (not `assert`) so it still runs under `python -O`;
    # otherwise zip() below would silently drop unmatched prompts.
    if len(ip_img_paths) != len(prompts):
        raise ValueError(
            "Number of ref images does not match prompts, "
            f"{len(ip_img_paths)} != {len(prompts)}"
        )

    if pipeline is None:
        if args.ref_scale > 0:
            pipeline = build_text2img_ip_pipeline(
                KOLORS_WEIGHTS_DIR,
                ref_scale=args.ref_scale,
            )
        else:
            pipeline = build_text2img_pipeline(KOLORS_WEIGHTS_DIR)

    # Accumulate across all prompts (previously reset per prompt, which made
    # the function return only the last prompt's images).
    save_paths: list[str] = []
    for idx, (prompt, ip_img_path) in tqdm(
        enumerate(zip(prompts, ip_img_paths)),
        desc="Generating images",
        total=len(prompts),
    ):
        images = text2img_gen(
            prompt=prompt,
            n_sample=args.n_sample,
            guidance_scale=args.guidance_scale,
            pipeline=pipeline,
            ip_image=ip_img_path,
            image_wh=[args.resolution, args.resolution],
            infer_step=args.infer_step,
        )
        for sub_idx, image in enumerate(images):
            # Global index keeps file names unique across prompts; assumes
            # text2img_gen returns at most n_sample images — TODO confirm.
            sample_id = idx * args.n_sample + sub_idx
            save_path = f"{args.output_root}/sample_{sample_id}.png"
            image.save(save_path)
            save_paths.append(save_path)

    logger.info("Images saved to %s", args.output_root)
    return save_paths


if __name__ == "__main__":
    entrypoint()