Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import logging | |
import os | |
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import ( | |
StableDiffusionXLPipeline, | |
) | |
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa | |
StableDiffusionXLPipeline as StableDiffusionXLPipelineIP, | |
) | |
from tqdm import tqdm | |
from asset3d_gen.models.text_model import ( | |
build_text2img_ip_pipeline, | |
build_text2img_pipeline, | |
text2img_gen, | |
) | |
# Module-level logger; entrypoint() uses it to report the output directory.
logger = logging.getLogger(__name__)
# Set the root logger's level to INFO and attach a default handler
# (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
def parse_args(): | |
parser = argparse.ArgumentParser(description="Text to Image.") | |
parser.add_argument( | |
"--prompts", | |
type=str, | |
nargs="+", | |
help="List of prompts (space-separated).", | |
) | |
parser.add_argument( | |
"--ref_image", | |
type=str, | |
nargs="+", | |
help="List of ref_image paths (space-separated).", | |
) | |
parser.add_argument( | |
"--output_root", | |
type=str, | |
help="Root directory for saving outputs.", | |
) | |
parser.add_argument( | |
"--guidance_scale", | |
type=float, | |
default=12.0, | |
help="Guidance scale for the diffusion model.", | |
) | |
parser.add_argument( | |
"--ref_scale", | |
type=float, | |
default=0.3, | |
help="Reference image scale for the IP adapter.", | |
) | |
parser.add_argument( | |
"--n_sample", | |
type=int, | |
default=1, | |
) | |
parser.add_argument( | |
"--resolution", | |
type=int, | |
default=1024, | |
) | |
parser.add_argument( | |
"--infer_step", | |
type=int, | |
default=50, | |
) | |
args = parser.parse_args() | |
return args | |
def entrypoint(
    pipeline: (
        StableDiffusionXLPipeline | StableDiffusionXLPipelineIP | None
    ) = None,
    **kwargs,
) -> list[str]:
    """Generate images for every prompt and save them as PNG files.

    Options come from ``parse_args()``; any keyword argument whose name
    matches a parsed option and whose value is not ``None`` overrides it.

    Args:
        pipeline: Pre-built text-to-image pipeline. When ``None`` one is
            built from ``"weights/Kolors"``: the IP-adapter variant if
            ``ref_scale`` > 0 after reference-image handling, otherwise the
            plain variant.
        **kwargs: Overrides for parsed options (e.g. ``prompts``,
            ``ref_image``, ``output_root``, ``n_sample``).

    Returns:
        Paths of all images written, in generation order, across all prompts.

    Raises:
        ValueError: If no prompts are supplied, or ``ref_image`` is neither
            ``None``/empty, a ``str`` nor a ``list``.
        AssertionError: If the number of reference images (after
            broadcasting a single one) differs from the number of prompts.
    """
    args = parse_args()
    for key, value in kwargs.items():
        if value is not None and hasattr(args, key):
            setattr(args, key, value)

    prompts = args.prompts
    if not prompts:
        raise ValueError(
            "No prompts provided; pass --prompts or the `prompts` keyword."
        )
    if isinstance(prompts, str):
        # A bare string would otherwise be zipped character by character;
        # treat it as one prompt, mirroring the str handling of ref_image.
        prompts = [prompts]
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        # A single ``.txt`` argument is a prompt file: one prompt per
        # non-blank line.
        with open(prompts[0], "r") as f:
            prompts = [line.strip() for line in f if line.strip() != ""]

    os.makedirs(args.output_root, exist_ok=True)

    ip_img_paths = args.ref_image
    if ip_img_paths is None or len(ip_img_paths) == 0:
        # No reference image: disable the IP-adapter path entirely.
        args.ref_scale = 0
        ip_img_paths = [None] * len(prompts)
    elif isinstance(ip_img_paths, str):
        ip_img_paths = [ip_img_paths] * len(prompts)
    elif isinstance(ip_img_paths, list):
        # A single reference image is shared by every prompt; longer lists
        # must already line up one-to-one (checked below).
        if len(ip_img_paths) == 1:
            ip_img_paths = ip_img_paths * len(prompts)
    else:
        raise ValueError("Invalid ref_image paths.")

    assert len(ip_img_paths) == len(
        prompts
    ), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}"  # noqa

    if pipeline is None:
        if args.ref_scale > 0:
            pipeline = build_text2img_ip_pipeline(
                "weights/Kolors",
                ref_scale=args.ref_scale,
            )
        else:
            pipeline = build_text2img_pipeline("weights/Kolors")

    # Collected across ALL prompts; creating this inside the loop would
    # return only the last prompt's files and leave it undefined when there
    # are no prompts to process.
    save_paths = []
    for idx, (prompt, ip_img_path) in tqdm(
        enumerate(zip(prompts, ip_img_paths)),
        desc="Generating images",
        total=len(prompts),
    ):
        images = text2img_gen(
            prompt=prompt,
            n_sample=args.n_sample,
            guidance_scale=args.guidance_scale,
            pipeline=pipeline,
            ip_image=ip_img_path,
            image_wh=[args.resolution, args.resolution],
            infer_step=args.infer_step,
        )

        for sub_idx, image in enumerate(images):
            # Global sample index: each prompt owns a run of n_sample slots.
            save_path = (
                f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
            )
            image.save(save_path)
            save_paths.append(save_path)

        logger.info(f"Images saved to {args.output_root}")

    return save_paths
# Script entry point: run generation using only the command-line options
# (no pre-built pipeline and no keyword overrides are passed).
if __name__ == "__main__":
    entrypoint()