|
"""Adapted from TODO""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
|
|
import numpy as np |
|
import torch |
|
from diffusers import DiffusionPipeline, StableDiffusionPipeline |
|
from einops import rearrange |
|
from PIL import Image |
|
from pytorch_lightning import seed_everything |
|
from torchvision.transforms import ToTensor |
|
from torchvision.utils import make_grid |
|
from tqdm import tqdm, trange |
|
|
|
# Inference-only script: disable autograd globally to avoid building graphs and save memory.
torch.set_grad_enabled(False)
|
|
|
|
|
def parse_args():
    """Build the CLI parser for the sampling script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options (metadata_file, model, outdir,
        n_samples, steps, negative_prompt, H, W, scale, seed, batch_size,
        skip_grid).
    """
    p = argparse.ArgumentParser()
    # Required positional input: one JSON object per line, each describing a prompt.
    p.add_argument("metadata_file", type=str, help="JSONL file containing lines of metadata for each prompt")
    p.add_argument("--model", type=str, default="runwayml/stable-diffusion-v1-5", help="Huggingface model name")
    p.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs")
    p.add_argument("--n_samples", type=int, default=4, help="number of samples")
    p.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps")
    # nargs="?" + const: passing the bare flag enables a canned negative prompt,
    # omitting it leaves the value None, and an explicit string overrides both.
    p.add_argument(
        "--negative-prompt",
        type=str,
        nargs="?",
        const="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face",
        default=None,
        help="negative prompt for guidance",
    )
    # None lets the pipeline pick its native resolution.
    p.add_argument("--H", type=int, default=None, help="image height, in pixel space")
    p.add_argument("--W", type=int, default=None, help="image width, in pixel space")
    p.add_argument(
        "--scale",
        type=float,
        default=9.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    p.add_argument("--seed", type=int, default=42, help="the seed (for reproducible sampling)")
    p.add_argument("--batch_size", type=int, default=1, help="how many samples can be produced simultaneously")
    p.add_argument("--skip_grid", action="store_true", help="skip saving grid")
    return p.parse_args()
|
|
|
|
|
def main(opt):
    """Generate `opt.n_samples` images for every prompt in a JSONL metadata file.

    For each metadata line (a JSON object with at least a ``"prompt"`` key),
    samples are written to ``<outdir>/<index:05>/samples/<k:05>.png``, the
    line's metadata is copied to ``<outdir>/<index:05>/metadata.jsonl``, and
    (unless ``--skip_grid``) a single ``grid.png`` of all samples is saved.

    Args:
        opt: argparse.Namespace produced by ``parse_args()``.
    """
    with open(opt.metadata_file) as fp:
        metadatas = [json.loads(line) for line in fp]

    # SDXL needs the generic DiffusionPipeline entry point and its fp16 weight
    # variant; every other model id goes through StableDiffusionPipeline.
    if opt.model == "stabilityai/stable-diffusion-xl-base-1.0":
        model = DiffusionPipeline.from_pretrained(
            opt.model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
        )
        model.enable_xformers_memory_efficient_attention()
    else:
        model = StableDiffusionPipeline.from_pretrained(opt.model, torch_dtype=torch.float16)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    model.enable_attention_slicing()

    for index, metadata in enumerate(metadatas):
        # Re-seed per prompt so every prompt starts from the same RNG state.
        seed_everything(opt.seed)

        outpath = os.path.join(opt.outdir, f"{index:0>5}")
        os.makedirs(outpath, exist_ok=True)

        prompt = metadata["prompt"]
        n_rows = batch_size = opt.batch_size
        print(f"Prompt ({index: >3}/{len(metadatas)}): '{prompt}'")

        sample_path = os.path.join(outpath, "samples")
        os.makedirs(sample_path, exist_ok=True)
        # Keep a copy of this prompt's metadata next to its samples.
        with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp:
            json.dump(metadata, fp)

        sample_count = 0

        with torch.no_grad():
            all_samples = list()
            # Ceil-divide so a final partial batch still gets generated.
            for _ in trange((opt.n_samples + batch_size - 1) // batch_size, desc="Sampling"):
                samples = model(
                    prompt,
                    height=opt.H,
                    width=opt.W,
                    num_inference_steps=opt.steps,
                    guidance_scale=opt.scale,
                    # Cap the last batch so exactly n_samples images are produced.
                    num_images_per_prompt=min(batch_size, opt.n_samples - sample_count),
                    negative_prompt=opt.negative_prompt or None,
                ).images
                for sample in samples:
                    sample.save(os.path.join(sample_path, f"{sample_count:05}.png"))
                    sample_count += 1
                if not opt.skip_grid:
                    all_samples.append(torch.stack([ToTensor()(sample) for sample in samples], 0))

            if not opt.skip_grid:
                # BUGFIX: when n_samples is not a multiple of batch_size the
                # final batch tensor has a smaller leading dimension, so
                # torch.stack would raise on mismatched shapes; torch.cat
                # flattens the batches directly (equivalent to the old
                # stack + rearrange for equal-sized batches).
                grid = torch.cat(all_samples, 0)
                grid = make_grid(grid, nrow=n_rows)

                # Convert CHW float [0,1] tensor to an HWC uint8 PIL image.
                grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                grid = Image.fromarray(grid.astype(np.uint8))
                grid.save(os.path.join(outpath, "grid.png"))
                del grid
            del all_samples

    print("Done.")
|
|
|
|
|
if __name__ == "__main__": |
|
opt = parse_args() |
|
main(opt) |
|
|