"""Gradio demo for Bounded Attention on Stable Diffusion XL.

The user types a prompt, marks which prompt tokens belong to which subject,
sketches one bounding box per subject, and the app generates images whose
subjects are guided towards those boxes.
"""

# NOTE: the import order of the original file is kept unchanged on purpose.
import spaces
import gradio as gr
import torch
import nltk
import numpy as np
from PIL import Image, ImageDraw
from diffusers import DDIMScheduler
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
from injection_utils import regiter_attention_editor_diffusers
from bounded_attention import BoundedAttention
from pytorch_lightning import seed_everything
from functools import partial


# Side length (in pixels) of the sketch pad and of the layout preview image.
RESOLUTION = 256
# Minimal box extent, expressed as a fraction of the sketch pad side.
MIN_SIZE = 0.01
WHITE = 255
# Outline colours, cycled through when drawing the boxes of a layout.
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]

PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"

# Boxes for the bundled examples, keyed by prompt.  Each box is
# [x1, y1, x2, y2] in coordinates relative to the image size (0..1).
EXAMPLE_BOXES = {
    PROMPT1: [
        [0.35, 0.4, 0.65, 0.9],
        [0, 0.6, 0.3, 0.9],
        [0.7, 0.55, 1, 0.85],
    ],
    PROMPT2: [
        [0.4, 0.45, 0.6, 0.95],
        [0.2, 0.3, 0.4, 0.85],
        [0.6, 0.3, 0.8, 0.85],
        [0.1, 0, 0.9, 0.3],
    ],
    PROMPT3: [
        [0, 0.5, 0.2, 0.8],
        [0.2, 0.2, 0.4, 0.5],
        [0.4, 0.5, 0.6, 0.8],
        [0.6, 0.2, 0.8, 0.5],
        [0.8, 0.5, 1, 0.8],
    ],
}

# NOTE(review): the CSS below targets `#paper-info` and `.tooltip` elements
# that do not appear in the HTML snippets further down; the markup of those
# snippets may have been lost before this file reached review -- confirm
# against the deployed version.  All four strings are kept byte-for-byte.
CSS = (
    " #paper-info a { color:#008AD7; text-decoration: none; }"
    " #paper-info a:hover { cursor: pointer; text-decoration: none; }"
    " .tooltip { color: #555; position: relative; display: inline-block; cursor: pointer; }"
    " .tooltip .tooltiptext { visibility: hidden; width: 400px; background-color: #555;"
    " color: #fff; text-align: center; padding: 5px; border-radius: 5px; position: absolute;"
    " z-index: 1; /* Set z-index to 1 */ left: 10px; top: 100%; opacity: 0;"
    " transition: opacity 0.3s; }"
    " .tooltip:hover .tooltiptext { visibility: visible; opacity: 1; z-index: 9999;"
    " /* Set a high z-index value when hovering */ }"
    " "
)

HEADER_HTML = """

Bounded Attention
[Project Page] [Paper] [GitHub]

"""

ADVANCED_OPTIONS_HTML = """
Batch size ⓘ The number of images to generate.
Initial step size ⓘ The initial step size of the linear step size scheduler when performing guidance.
Final step size ⓘ The final step size of the linear step size scheduler when performing guidance.
Number of self-attention clusters per subject ⓘ Determines the number of clusters when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.
Cross-attention loss scale factor ⓘ The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.
Self-attention loss scale factor ⓘ The scale factor of the self-attention loss term. Increasing it will improve layout control (adherence to the bounding boxes), but may reduce image quality.
Classifier-free guidance scale ⓘ The scale factor of classifier-free guidance.
Number of Gradient Descent iterations per timestep ⓘ The number of Gradient Descent iterations for each timestep when performing guidance.
Loss Threshold ⓘ If the loss is below the threshold, Gradient Descent stops for that timestep.
Number of guidance steps ⓘ The number of timesteps in which to perform guidance.
"""

FOOTER_HTML = """

The source code of this demo is based on the GLIGEN demo.

"""


def inference(
    boxes,
    prompts,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
):
    """Load the SDXL pipeline, attach a BoundedAttention editor and sample.

    Args:
        boxes: One [x1, y1, x2, y2] box per subject.
        prompts: List of prompts; its length is the number of images.
        subject_token_indices: One list of token indices per subject.
        filter_token_indices: Token indices to filter, or None.
        num_tokens: Number of prompt tokens as an int, or None when the
            user left the field empty.
        init_step_size, final_step_size, num_clusters_per_subject,
        cross_loss_scale, self_loss_scale, num_iterations, loss_threshold,
        num_guidance_steps: Forwarded to ``BoundedAttention``.
        classifier_free_guidance_scale: ``guidance_scale`` of the pipeline.
        seed: Seed handed to ``seed_everything`` before sampling.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        gr.Error: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise gr.Error("cuda is not available")

    device = torch.device("cuda")
    model_path = "stabilityai/stable-diffusion-xl-base-1.0"
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    model = StableDiffusionXLPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float16,
    ).to(device)
    model.unet.set_default_attn_processor()
    model.enable_sequential_cpu_offload()

    seed_everything(seed)
    start_code = torch.randn([len(prompts), 4, 128, 128], device=device)

    # `num_tokens` is optional in the UI and reaches this function as None
    # when left empty; the original `num_tokens + 1` raised TypeError then.
    # NOTE(review): assumes BoundedAttention accepts eos_token_index=None and
    # derives the EOS position itself -- confirm in bounded_attention.py.
    eos_token_index = None if num_tokens is None else num_tokens + 1

    editor = BoundedAttention(
        boxes,
        prompts,
        subject_token_indices,
        # NOTE(review): presumably the attention layers used for the
        # cross- and self-attention losses -- confirm in BoundedAttention.
        list(range(70, 82)),
        list(range(70, 82)),
        filter_token_indices=filter_token_indices,
        eos_token_index=eos_token_index,
        cross_loss_coef=cross_loss_scale,
        self_loss_coef=self_loss_scale,
        max_guidance_iter=num_guidance_steps,
        max_guidance_iter_per_step=num_iterations,
        start_step_size=init_step_size,
        end_step_size=final_step_size,
        loss_stopping_value=loss_threshold,
        num_clusters_per_box=num_clusters_per_subject,
    )

    regiter_attention_editor_diffusers(model, editor)
    return model(
        prompts,
        latents=start_code,
        guidance_scale=classifier_free_guidance_scale,
    ).images


@spaces.GPU(duration=500)
def generate(
    prompt,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    batch_size,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
    boxes,
):
    """Parse the raw UI values and run :func:`inference`.

    Args:
        prompt: Text prompt; surrounding '.', ',' and whitespace are removed.
        subject_token_indices: String such as ``"7,8;11,12"`` (commas inside
            a subject, semicolons between subjects).
        filter_token_indices: Comma-separated indices, or a blank string.
        num_tokens: Number of tokens as a string, or a blank string.
        batch_size: How many copies of the prompt to generate.
        boxes: One box per subject.
        (remaining arguments are forwarded unchanged to ``inference``)

    Returns:
        The generated images.

    Raises:
        gr.Error: If the number of boxes differs from the number of subjects.
    """
    subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
    if len(boxes) != len(subject_token_indices):
        message = (
            " The number of boxes should be equal to the number of subjects."
            " Number of boxes drawn: {}, number of subjects: {}. \n"
        )
        raise gr.Error(message.format(len(boxes), len(subject_token_indices)))

    # Both fields are optional: a blank string means "not provided".
    if filter_token_indices.strip():
        filter_token_indices = convert_token_indices(filter_token_indices)
    else:
        filter_token_indices = None
    num_tokens = int(num_tokens) if num_tokens.strip() else None

    # List repetition needs a real int; the slider value may arrive as float.
    prompts = [prompt.strip('.').strip(',').strip()] * int(batch_size)

    return inference(
        boxes,
        prompts,
        subject_token_indices,
        filter_token_indices,
        num_tokens,
        init_step_size,
        final_step_size,
        num_clusters_per_subject,
        cross_loss_scale,
        self_loss_scale,
        classifier_free_guidance_scale,
        num_iterations,
        loss_threshold,
        num_guidance_steps,
        seed,
    )


def convert_token_indices(token_indices, nested=False):
    """Parse a string of token indices into ints.

    With ``nested=False`` the string is split on commas and blank items are
    skipped, e.g. ``"1, 2,"`` -> ``[1, 2]``.  With ``nested=True`` the string
    is first split on semicolons and every part is parsed as above, e.g.
    ``"1,2;3"`` -> ``[[1, 2], [3]]``.

    Raises:
        ValueError: If a non-blank item is not an integer.
    """
    if nested:
        return [
            convert_token_indices(indices, nested=False)
            for indices in token_indices.split(';')
        ]

    return [
        int(index.strip())
        for index in token_indices.split(',')
        if len(index.strip()) > 0
    ]


def draw(sketchpad):
    """Turn every sketch pad layer into a bounding box and preview them.

    Each layer's box is the extent of its non-zero entries, normalised by
    the layer's height and width.

    Returns:
        ``[boxes, layout_image]`` where boxes is a list of
        ``(x1, y1, x2, y2)`` tuples.

    Raises:
        gr.Error: If a layer is empty or its box is narrower or shorter
            than ``MIN_SIZE``.
    """
    boxes = []
    for i, layer in enumerate(sketchpad['layers']):
        non_zeros = layer.nonzero()
        x1 = x2 = y1 = y2 = 0
        if len(non_zeros[0]) > 0:
            # Index 0 holds row (y) positions, index 1 column (x) positions.
            x1x2 = non_zeros[1] / layer.shape[1]
            y1y2 = non_zeros[0] / layer.shape[0]
            x1 = x1x2.min()
            x2 = x1x2.max()
            y1 = y1y2.min()
            y2 = y1y2.max()

        # An empty layer keeps the all-zero box and is rejected here too.
        if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
            raise gr.Error(f'Box in layer {i} is too small')

        boxes.append((x1, y1, x2, y2))

    layout_image = draw_boxes(boxes)
    return [boxes, layout_image]


def draw_boxes(boxes):
    """Render relative boxes as coloured outlines on a white square image.

    Returns:
        A ``RESOLUTION`` x ``RESOLUTION`` RGB image, or None when ``boxes``
        is empty.
    """
    if len(boxes) == 0:
        return None

    boxes = np.array(boxes) * RESOLUTION
    image = Image.new('RGB', (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
    drawing = ImageDraw.Draw(image)
    for i, box in enumerate(boxes.astype(int).tolist()):
        drawing.rectangle(box, outline=COLORS[i % len(COLORS)], width=4)

    return image


def clear(batch_size):
    """Return reset values for boxes, sketch pad, layout and gallery.

    ``batch_size`` is accepted because the click handler passes it, but it
    is not used.
    """
    return [[], None, None, None]


def generate_example(
    prompt,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    batch_size,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
):
    """Run one of the bundled examples using its boxes from EXAMPLE_BOXES.

    Returns:
        ``(boxes, sketchpad, layout_image, out_images)``.

    Raises:
        KeyError: If ``prompt`` is not a key of ``EXAMPLE_BOXES``.
    """
    boxes = EXAMPLE_BOXES[prompt]
    # One single-box image per layer for the sketch pad.
    layers = [draw_boxes([box]) for box in boxes]
    sketchpad = {'layers': layers}
    layout_image = draw_boxes(boxes)

    out_images = generate(
        prompt,
        subject_token_indices,
        filter_token_indices,
        num_tokens,
        init_step_size,
        final_step_size,
        num_clusters_per_subject,
        cross_loss_scale,
        self_loss_scale,
        classifier_free_guidance_scale,
        batch_size,
        num_iterations,
        loss_threshold,
        num_guidance_steps,
        seed,
        boxes,
    )

    # The original returned the undefined name `layout_image` while the
    # local was called `layout_images`, raising NameError on every example.
    return boxes, sketchpad, layout_image, out_images


def main():
    """Build the Gradio interface, wire up its callbacks and launch it."""
    nltk.download('averaged_perceptron_tagger')

    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a source whose indentation had been lost -- confirm
    # against the deployed version.
    with gr.Blocks(
        css=CSS,
        title="Bounded Attention demo",
    ) as demo:
        gr.HTML(HEADER_HTML)

        with gr.Column():
            prompt = gr.Textbox(
                label="Text prompt",
            )

            subject_token_indices = gr.Textbox(
                label="The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
            )

            filter_token_indices = gr.Textbox(
                label="Optional: The token indices to filter, i.e. conjunctions, numbers, postional relations, etc. (if left empty, this will be automatically inferred)",
            )

            num_tokens = gr.Textbox(
                label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
            )

            with gr.Row():
                sketchpad = gr.Sketchpad(label="Sketch Pad", width=RESOLUTION, height=RESOLUTION)
                layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION, scale=1)

            with gr.Row():
                clear_button = gr.Button(value='Clear')
                generate_layout_button = gr.Button(value='Generate layout')
                generate_image_button = gr.Button(value='Generate image')

            with gr.Row():
                out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)

            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    gr.HTML(ADVANCED_OPTIONS_HTML)
                    batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
                    init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=18, label="Initial step size")
                    final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=5, label="Final step size")
                    num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
                    cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
                    self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
                    classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale")
                    num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations")
                    loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold")
                    num_guidance_steps = gr.Slider(minimum=10, maximum=20, step=1, value=15, label="Number of timesteps to perform guidance")
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")

        boxes = gr.State([])

        # Shared by the "Generate image" button and the examples so that the
        # two argument orders cannot drift apart; the button adds `boxes`.
        generation_inputs = [
            prompt,
            subject_token_indices,
            filter_token_indices,
            num_tokens,
            init_step_size,
            final_step_size,
            num_clusters_per_subject,
            cross_loss_scale,
            self_loss_scale,
            classifier_free_guidance_scale,
            batch_size,
            num_iterations,
            loss_threshold,
            num_guidance_steps,
            seed,
        ]

        clear_button.click(
            clear,
            inputs=[batch_size],
            outputs=[boxes, sketchpad, layout_image, out_images],
            queue=False,
        )

        generate_layout_button.click(
            draw,
            inputs=[sketchpad],
            outputs=[boxes, layout_image],
            queue=False,
        )

        generate_image_button.click(
            fn=generate,
            inputs=generation_inputs + [boxes],
            outputs=[out_images],
            queue=True,
        )

        with gr.Column():
            gr.Examples(
                # The prompts reuse the PROMPT* constants because
                # generate_example looks the boxes up by exactly these keys.
                examples=[
                    [
                        PROMPT1,
                        "7,8,17;11,12,17;15,16,17",
                        "5,6,9,10,13,14,18,19",
                        "21",
                        25, 10, 3, 1, 1, 7.5, 1, 5, 0.2, 15, 286,
                    ],
                    [
                        PROMPT2,
                        "7;10;13,14;17",
                        "5,6,8,9,11,12,15,16",
                        "17",
                        18, 5, 3, 1, 1, 7.5, 1, 5, 0.2, 15, 216,
                    ],
                    [
                        PROMPT3,
                        "2,3;6,7;10,11;14,15;18,19",
                        "1,4,5,8,9,12,13,16,17,20,21",
                        "22",
                        18, 5, 3, 1, 1, 7.5, 1, 5, 0.2, 15, 156,
                    ],
                ],
                fn=generate_example,
                inputs=generation_inputs,
                outputs=[boxes, sketchpad, layout_image, out_images],
                cache_examples=True,
            )

        gr.HTML(FOOTER_HTML)

    demo.launch(show_api=False, show_error=True)


if __name__ == '__main__':
    main()