|
import os |
|
import json |
|
import time |
|
import argparse |
|
import imageio |
|
import torch |
|
import numpy as np |
|
from torchvision import transforms |
|
|
|
from models.region_diffusion import RegionDiffusion |
|
from utils.attention_utils import get_token_maps |
|
from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\ |
|
get_attention_control_input, get_gradient_guidance_input |
|
|
|
|
|
def main(args, param):
    """Render a plain image and a rich-text-conditioned image, saving both.

    Two sampling passes are run with the same seed:

    1. ``model.produce_attn_maps`` on the base prompt alone; its first image
       is written to ``<run_dir>/seed<seed>_plain.png`` and the attention
       maps collected on the model are turned into token masks.
    2. ``model.prompt_to_img`` on the region prompts, using those masks and
       the text-format dictionary; its first image is written to
       ``<run_dir>/seed<seed>_rich.png``.

    Args:
        args: Object exposing ``run_dir``, the output directory (created if
            it does not exist).
        param: Mapping with the keys ``'text_input'``, ``'height'``,
            ``'width'``, ``'noise_index'`` (used as the random seed),
            ``'negative_prompt'``, ``'steps'`` and ``'guidance_weight'``.
    """
    out_dir = args.run_dir
    os.makedirs(out_dir, exist_ok=True)

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = RegionDiffusion(device)

    # Split the rich-text input into its per-attribute pieces.
    (base_text_prompt, style_text_prompts, footnote_text_prompts,
     footnote_target_tokens, color_text_prompts, color_names, color_rgbs,
     size_text_prompts_and_sizes, use_grad_guidance) = parse_json(
        param['text_input'])

    # Per-region prompts and the token ids each region is tied to.
    region_text_prompts, region_target_token_ids, base_tokens = \
        get_region_diffusion_input(
            model, base_text_prompt, style_text_prompts,
            footnote_text_prompts, footnote_target_tokens,
            color_text_prompts, color_names)

    # The format dictionary is built from the size attributes first and then
    # passed through the colour helper, which returns it together with the
    # colour target token ids.
    text_format_dict = get_attention_control_input(
        model, base_tokens, size_text_prompts_and_sizes)
    text_format_dict, color_target_token_ids = get_gradient_guidance_input(
        model, base_tokens, color_text_prompts, color_rgbs, text_format_dict)

    height = param['height']
    width = param['width']
    seed = param['noise_index']
    negative_text = param['negative_prompt']
    seed_everything(seed)

    # ---- Pass 1: plain-text sampling, used to collect attention maps. ----
    tic = time.time()
    if model.attention_maps is not None:
        model.reset_attention_maps()
    else:
        model.register_evaluation_hooks()

    # Sampler settings shared by both passes.
    sampling_kwargs = dict(
        height=height,
        width=width,
        num_inference_steps=param['steps'],
        guidance_scale=param['guidance_weight'],
    )
    plain_img = model.produce_attn_maps(
        [base_text_prompt], [negative_text], **sampling_kwargs)
    plain_path = os.path.join(out_dir, 'seed%d_plain.png' % seed)
    imageio.imwrite(plain_path, plain_img[0])
    print('time lapses to get attention maps: %.4f' % (time.time() - tic))

    # Token maps are requested at 1/8 of the image size.
    # NOTE(review): presumably the latent resolution of the model - confirm
    # against get_token_maps / RegionDiffusion.
    map_w = width // 8
    map_h = height // 8
    color_obj_masks = get_token_maps(
        model.attention_maps, out_dir, map_w, map_h,
        color_target_token_ids, seed)
    model.masks = get_token_maps(
        model.attention_maps, out_dir, map_w, map_h,
        region_target_token_ids, seed, base_tokens)

    # Bring the colour masks up to the full image size before they are
    # stored in the format dictionary.
    resized_masks = []
    for mask in color_obj_masks:
        resized_masks.append(
            transforms.functional.resize(
                mask, (height, width),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True))
    text_format_dict['color_obj_atten'] = resized_masks
    model.remove_evaluation_hooks()

    # ---- Pass 2: rich-text sampling, re-seeded with the same seed. ----
    tic = time.time()
    seed_everything(seed)
    rich_img = model.prompt_to_img(
        region_text_prompts, [negative_text],
        use_grad_guidance=use_grad_guidance,
        text_format_dict=text_format_dict,
        **sampling_kwargs)
    print('time lapses to generate image from rich text: %.4f' %
          (time.time() - tic))
    rich_path = os.path.join(out_dir, 'seed%d_rich.png' % seed)
    imageio.imwrite(rich_path, rich_img[0])
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Default rich-text document, passed as a JSON string with an "ops" list;
    # here the word "cat" carries a "link" attribute holding a longer
    # description.
    default_rich_text = '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. There are palm trees in the background."}]}'

    # (flag, type, default) for every command-line option, in the order they
    # are registered with the parser.
    cli_options = (
        ('--run_dir', str, 'results/release/debug'),
        ('--height', int, 512),
        ('--width', int, 512),
        ('--seed', int, 6),
        ('--sample_steps', int, 41),
        ('--rich_text_json', str, default_rich_text),
        ('--negative_prompt', str, ''),
        ('--guidance_weight', float, 8.5),
    )

    parser = argparse.ArgumentParser()
    for flag, option_type, option_default in cli_options:
        parser.add_argument(flag, type=option_type, default=option_default)
    args = parser.parse_args()

    # Translate the CLI names into the keys that main() reads from `param`.
    param = dict(
        text_input=json.loads(args.rich_text_json),
        height=args.height,
        width=args.width,
        guidance_weight=args.guidance_weight,
        steps=args.sample_steps,
        noise_index=args.seed,
        negative_prompt=args.negative_prompt,
    )

    main(args, param)
|
|