"""Adapted from TODO""" import argparse import json import os import numpy as np import torch from diffusers import DiffusionPipeline, StableDiffusionPipeline from einops import rearrange from PIL import Image from pytorch_lightning import seed_everything from torchvision.transforms import ToTensor from torchvision.utils import make_grid from tqdm import tqdm, trange torch.set_grad_enabled(False) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("metadata_file", type=str, help="JSONL file containing lines of metadata for each prompt") parser.add_argument("--model", type=str, default="runwayml/stable-diffusion-v1-5", help="Huggingface model name") parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs") parser.add_argument( "--n_samples", type=int, default=4, help="number of samples", ) parser.add_argument( "--steps", type=int, default=50, help="number of ddim sampling steps", ) parser.add_argument( "--negative-prompt", type=str, nargs="?", const="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face", default=None, help="negative prompt for guidance", ) parser.add_argument( "--H", type=int, default=None, help="image height, in pixel space", ) parser.add_argument( "--W", type=int, default=None, help="image width, in pixel space", ) parser.add_argument( "--scale", type=float, default=9.0, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", ) parser.add_argument( "--seed", type=int, default=42, help="the seed (for reproducible sampling)", ) parser.add_argument( "--batch_size", type=int, default=1, help="how many samples can be produced simultaneously", ) parser.add_argument( "--skip_grid", action="store_true", help="skip saving grid", ) opt = parser.parse_args() return opt def main(opt): # Load prompts with open(opt.metadata_file) as fp: metadatas = [json.loads(line) for line in fp] # Load model if opt.model == "stabilityai/stable-diffusion-xl-base-1.0": model = DiffusionPipeline.from_pretrained( opt.model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16" ) model.enable_xformers_memory_efficient_attention() else: model = StableDiffusionPipeline.from_pretrained(opt.model, torch_dtype=torch.float16) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) model.enable_attention_slicing() for index, metadata in enumerate(metadatas): seed_everything(opt.seed) outpath = os.path.join(opt.outdir, f"{index:0>5}") os.makedirs(outpath, exist_ok=True) prompt = metadata["prompt"] n_rows = batch_size = opt.batch_size print(f"Prompt ({index: >3}/{len(metadatas)}): '{prompt}'") sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp: json.dump(metadata, fp) sample_count = 0 with torch.no_grad(): all_samples = list() for n in trange((opt.n_samples + batch_size - 1) // batch_size, desc="Sampling"): # Generate images samples = model( prompt, height=opt.H, width=opt.W, num_inference_steps=opt.steps, guidance_scale=opt.scale, num_images_per_prompt=min(batch_size, opt.n_samples - sample_count), negative_prompt=opt.negative_prompt or None, ).images for sample in samples: sample.save(os.path.join(sample_path, f"{sample_count:05}.png")) sample_count += 1 if not opt.skip_grid: all_samples.append(torch.stack([ToTensor()(sample) for sample in samples], 0)) if not opt.skip_grid: # additionally, save as grid grid = torch.stack(all_samples, 0) grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() grid = Image.fromarray(grid.astype(np.uint8)) grid.save(os.path.join(outpath, f"grid.png")) del grid del all_samples print("Done.") if __name__ == "__main__": opt = parse_args() main(opt)