# xinjie.wang
# update
# 55ed985
import argparse
import logging
import os
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from tqdm import tqdm
from asset3d_gen.models.text_model import (
build_text2img_ip_pipeline,
build_text2img_pipeline,
text2img_gen,
)
# Configure root logging once at import time so all loggers emit INFO and above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for text-to-image generation.

    Returns:
        argparse.Namespace: Parsed arguments with generation settings.
    """
    parser = argparse.ArgumentParser(description="Text to Image.")
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        help="List of prompts (space-separated).",
    )
    parser.add_argument(
        "--ref_image",
        type=str,
        nargs="+",
        help="List of ref_image paths (space-separated).",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=12.0,
        help="Guidance scale for the diffusion model.",
    )
    parser.add_argument(
        "--ref_scale",
        type=float,
        default=0.3,
        help="Reference image scale for the IP adapter.",
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=1,
        help="Number of images to generate per prompt.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help="Output image width and height in pixels.",
    )
    parser.add_argument(
        "--infer_step",
        type=int,
        default=50,
        help="Number of diffusion inference steps.",
    )
    return parser.parse_args()
def entrypoint(
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
    **kwargs,
) -> list[str]:
    """Generate images from text prompts, optionally guided by reference images.

    CLI arguments are parsed first; ``kwargs`` then override any matching
    attribute on the parsed namespace (only non-None values are applied).

    Args:
        pipeline: Pre-built text2img pipeline. When None, one is built from
            local "weights/Kolors" (IP-adapter variant when ref_scale > 0).
        **kwargs: Overrides for any CLI argument (e.g. ``prompts``,
            ``output_root``, ``ref_image``).

    Returns:
        Paths of all saved images, across every prompt.

    Raises:
        ValueError: If no prompts are provided or ref_image paths are invalid.
    """
    args = parse_args()
    # Apply caller overrides on top of CLI values/defaults.
    for key, value in kwargs.items():
        if hasattr(args, key) and value is not None:
            setattr(args, key, value)

    prompts = args.prompts
    if not prompts:
        # Fail early with a clear message instead of a TypeError on len(None).
        raise ValueError("No prompts given; pass --prompts or prompts=[...].")
    # A single ".txt" argument is treated as a file with one prompt per line.
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        with open(prompts[0], "r") as f:
            prompts = [line.strip() for line in f if line.strip()]

    os.makedirs(args.output_root, exist_ok=True)

    # Normalize reference images to exactly one path (or None) per prompt.
    ip_img_paths = args.ref_image
    if ip_img_paths is None or len(ip_img_paths) == 0:
        args.ref_scale = 0  # No reference images -> disable the IP adapter.
        ip_img_paths = [None] * len(prompts)
    elif isinstance(ip_img_paths, str):
        ip_img_paths = [ip_img_paths] * len(prompts)
    elif isinstance(ip_img_paths, list):
        if len(ip_img_paths) == 1:
            # One reference image shared by every prompt.
            ip_img_paths = ip_img_paths * len(prompts)
    else:
        raise ValueError("Invalid ref_image paths.")
    assert len(ip_img_paths) == len(
        prompts
    ), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}"  # noqa

    if pipeline is None:
        if args.ref_scale > 0:
            pipeline = build_text2img_ip_pipeline(
                "weights/Kolors",
                ref_scale=args.ref_scale,
            )
        else:
            pipeline = build_text2img_pipeline("weights/Kolors")

    # Accumulate paths across ALL prompts (previously reset per iteration,
    # so only the last prompt's images were returned).
    save_paths = []
    for idx, (prompt, ip_img_path) in tqdm(
        enumerate(zip(prompts, ip_img_paths)),
        desc="Generating images",
        total=len(prompts),
    ):
        images = text2img_gen(
            prompt=prompt,
            n_sample=args.n_sample,
            guidance_scale=args.guidance_scale,
            pipeline=pipeline,
            ip_image=ip_img_path,
            image_wh=[args.resolution, args.resolution],
            infer_step=args.infer_step,
        )
        for sub_idx, image in enumerate(images):
            # Globally unique index: idx * n_sample + sub_idx.
            save_path = (
                f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
            )
            image.save(save_path)
            save_paths.append(save_path)

    logger.info(f"Images saved to {args.output_root}")
    return save_paths
# CLI entry point: run generation with arguments parsed from sys.argv.
if __name__ == "__main__":
    entrypoint()