import spaces
import gradio as gr
import torch
import nltk
import numpy as np
from PIL import Image, ImageDraw
from diffusers import DDIMScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
from injection_utils import register_attention_editor_diffusers
from bounded_attention import BoundedAttention
from pytorch_lightning import seed_everything

REMOTE_MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
LOCAL_MODEL_PATH = "./model"

# Side length (pixels) of the square layout previews produced by `draw_boxes`.
RESOLUTION = 256
# Minimal normalized width/height of a box drawn on the sketch pad.
MIN_SIZE = 0.01
WHITE = 255

# Box outline colors for the layout preview; cycled when there are more boxes.
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]

PROMPT1 = "a ginger kitten and a gray puppy in a yard"
PROMPT2 = "3 D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
PROMPT3 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
PROMPT4 = "a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter"
PROMPT5 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"

# Example layouts: one normalized (x1, y1, x2, y2) box per subject of each prompt.
EXAMPLE_BOXES = {
    PROMPT1: [
        [0.15, 0.2, 0.45, 0.9],
        [0.55, 0.25, 0.85, 0.95],
    ],
    PROMPT2: [
        [0.35, 0.4, 0.65, 0.9],
        [0, 0.6, 0.3, 0.9],
        [0.7, 0.55, 1, 0.85],
    ],
    PROMPT3: [
        [0.4, 0.45, 0.6, 0.95],
        [0.2, 0.3, 0.4, 0.85],
        [0.6, 0.3, 0.8, 0.85],
        [0.1, 0, 0.9, 0.3],
    ],
    PROMPT4: [
        [0.05, 0.5, 0.45, 0.85],
        [0.55, 0.6, 0.95, 0.85],
        [0.3, 0.2, 0.7, 0.45],
    ],
    PROMPT5: [
        [0, 0.5, 0.2, 0.8],
        [0.2, 0.2, 0.4, 0.5],
        [0.4, 0.5, 0.6, 0.8],
        [0.6, 0.2, 0.8, 0.5],
        [0.8, 0.5, 1, 0.8],
    ],
}

CSS = """
#paper-info a {
    color:#008AD7;
    text-decoration: none;
}
#paper-info a:hover {
    cursor: pointer;
    text-decoration: none;
}

.tooltip {
    color: #555;
    position: relative;
    display: inline-block;
    cursor: pointer;
}

.tooltip .tooltiptext {
    visibility: hidden;
    width: 400px;
    background-color: #555;
    color: #fff;
    text-align: center;
    padding: 5px;
    border-radius: 5px;
    position: absolute;
    z-index: 1; /* Set z-index to 1 */
    left: 10px;
    top: 100%;
    opacity: 0;
    transition: opacity 0.3s;
}

.tooltip:hover .tooltiptext {
    visibility: visible;
    opacity: 1;
    z-index: 9999; /* Set a high z-index value when hovering */
}
"""

DESCRIPTION = """

Bounded Attention
[Project Page] [Paper] [GitHub]

"""

COPY_LINK = """ Duplicate Space Duplicate this space to generate more samples without waiting in queue. To get better results, use our code on your own GPU and increase the number of guidance steps to 15. """

ADVANCED_OPTION_DESCRIPTION = """
Number of guidance steps ⓘ The number of timesteps in which to perform guidance. Recommended value is 15, but increasing this will also increase the runtime.
Batch size ⓘ The number of images to generate.
Initial step size ⓘ The initial step size of the linear step size scheduler when performing guidance.
Final step size ⓘ The final step size of the linear step size scheduler when performing guidance.
First refinement step ⓘ The timestep from which subject mask refinement is performed.
Number of self-attention clusters per subject ⓘ The number of clusters computed when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.
Cross-attention loss scale factor ⓘ The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.
Self-attention loss scale factor ⓘ The scale factor of the self-attention loss term. Increasing it will improve layout control (adherence to the bounding boxes), but may reduce image quality.
Number of Gradient Descent iterations per timestep ⓘ The number of Gradient Descent iterations for each timestep when performing guidance.
Loss Threshold ⓘ If the loss is below the threshold, Gradient Descent stops for that timestep.
Classifier-free guidance scale ⓘ The scale factor of classifier-free guidance.
"""

FOOTNOTE = """

The source code of this demo is based on the GLIGEN demo.

"""


def inference(
    boxes,
    prompts,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    first_refinement_step,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
):
    """Run SDXL with a Bounded Attention editor attached and return the images.

    Args:
        boxes: one normalized ``(x1, y1, x2, y2)`` box per subject.
        prompts: the prompt repeated once per image in the batch.
        subject_token_indices: one list of token indices per subject.
        filter_token_indices: token indices to filter, or ``None``.
        num_tokens: number of prompt tokens, or ``None``; used to derive the
            EOS token index.
        Remaining arguments are forwarded to ``BoundedAttention`` and the
        pipeline call (see the keyword mapping below).

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        gr.Error: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise gr.Error("cuda is not available")

    device = torch.device("cuda")
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    # NOTE(review): the pipeline is reloaded from disk on every request; this is
    # slow but keeps each GPU call self-contained. Consider caching if the
    # hosting environment allows it.
    model = StableDiffusionXLPipeline.from_pretrained(
        LOCAL_MODEL_PATH,
        scheduler=scheduler,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model.to(device)
    model.unet.set_attn_processor(AttnProcessor2_0())
    # NOTE(review): sequential CPU offload manages module placement itself, so
    # the `.to(device)` above may be redundant -- confirm against the pinned
    # diffusers version before removing it.
    model.enable_sequential_cpu_offload()

    seed_everything(seed)
    # Initial latents: 4 channels at 128x128 -- presumably SDXL's 1024x1024
    # output divided by the VAE factor of 8; confirm if the resolution changes.
    start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
    # Presumably index 0 is the start token, so EOS follows the last prompt token.
    eos_token_index = None if num_tokens is None else num_tokens + 1

    editor = BoundedAttention(
        boxes,
        prompts,
        subject_token_indices,
        # Layer indices 70..81, passed as two separate lists -- presumably the
        # cross- and self-attention layers to guide; see bounded_attention.
        list(range(70, 82)),
        list(range(70, 82)),
        filter_token_indices=filter_token_indices,
        eos_token_index=eos_token_index,
        cross_loss_coef=cross_loss_scale,
        self_loss_coef=self_loss_scale,
        max_guidance_iter=num_guidance_steps,
        max_guidance_iter_per_step=num_iterations,
        start_step_size=init_step_size,
        end_step_size=final_step_size,
        loss_stopping_value=loss_threshold,
        min_clustering_step=first_refinement_step,
        num_clusters_per_box=num_clusters_per_subject,
        max_resolution=32,
    )

    register_attention_editor_diffusers(model, editor)
    return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images


@spaces.GPU(duration=340)
def generate(
    prompt,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    first_refinement_step,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    batch_size,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
    boxes,
):
    """Validate the raw UI inputs, then run `inference` on them.

    Text inputs are parsed here:

    * ``subject_token_indices``: ``"2,3;6,7"`` means two subjects.
    * ``filter_token_indices``: optional comma list; empty means ``None``.
    * ``num_tokens``: optional integer; empty means ``None``.

    The prompt is stripped of surrounding whitespace and leading/trailing
    periods and commas, then repeated ``batch_size`` times.

    Raises:
        gr.Error: on malformed indices, missing subjects, or when the number
            of drawn boxes differs from the number of subjects.
    """
    print('boxes in generate', boxes)

    try:
        subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
        filter_token_indices = (
            convert_token_indices(filter_token_indices) if filter_token_indices.strip() else None
        )
        num_tokens = int(num_tokens) if num_tokens.strip() else None
    except ValueError as err:
        # Surface parse errors in the UI instead of as an opaque traceback.
        raise gr.Error(
            "Token indices and the number of tokens must be integers "
            "(separate indices with commas and subjects with semicolons)."
        ) from err

    if not subject_token_indices or not all(subject_token_indices):
        raise gr.Error("Please specify at least one token index for every subject.")

    if len(boxes) != len(subject_token_indices):
        raise gr.Error(
            "The number of boxes should be equal to the number of subjects. "
            f"Number of boxes drawn: {len(boxes)}, number of subjects: {len(subject_token_indices)}."
        )

    # Strip whitespace first so that trailing punctuation such as "a cat. " is removed too.
    cleaned_prompt = prompt.strip().strip(".,").strip()
    prompts = [cleaned_prompt] * int(batch_size)

    return inference(
        boxes,
        prompts,
        subject_token_indices,
        filter_token_indices,
        num_tokens,
        init_step_size,
        final_step_size,
        first_refinement_step,
        num_clusters_per_subject,
        cross_loss_scale,
        self_loss_scale,
        classifier_free_guidance_scale,
        num_iterations,
        loss_threshold,
        num_guidance_steps,
        seed,
    )


def convert_token_indices(token_indices, nested=False):
    """Parse user-entered token indices.

    Flat mode: ``"1, 4,5"`` -> ``[1, 4, 5]``; blank entries are ignored.
    Nested mode: subjects separated by ``";"``, each parsed in flat mode:
    ``"2,3;6,7"`` -> ``[[2, 3], [6, 7]]``. Blank subject groups (e.g. from a
    trailing ``";"``) are ignored so they are not counted as extra subjects.

    Raises:
        ValueError: if an entry is not an integer.
    """
    if nested:
        return [
            convert_token_indices(group, nested=False)
            for group in token_indices.split(";")
            if group.strip()
        ]
    return [int(part) for part in token_indices.split(",") if part.strip()]


def draw(sketchpad):
    """Turn each non-empty sketch-pad layer into one normalized bounding box.

    The box of a layer spans all of its non-zero pixels, normalized by the
    layer's width/height. Layers with nothing drawn on them are skipped.

    Returns:
        ``[boxes, layout_image]``, where boxes are ``(x1, y1, x2, y2)`` tuples
        and ``layout_image`` is the preview from `draw_boxes` (``None`` when no
        box was drawn).

    Raises:
        gr.Error: if a drawn box is narrower or shorter than ``MIN_SIZE``.
    """
    boxes = []
    layers = sketchpad["layers"] if sketchpad else []
    for i, layer in enumerate(layers):
        rows, cols = layer.nonzero()[:2]
        if len(rows) == 0:
            # A layer that was added but never drawn on carries no box.
            continue

        height, width = layer.shape[0], layer.shape[1]
        x1, x2 = float(cols.min() / width), float(cols.max() / width)
        y1, y2 = float(rows.min() / height), float(rows.max() / height)

        if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
            raise gr.Error(f"Box in layer {i} is too small")

        boxes.append((x1, y1, x2, y2))

    layout_image = draw_boxes(boxes)
    return [boxes, layout_image]


def draw_boxes(boxes, is_sketch=False):
    """Render normalized boxes as outlines on a white RESOLUTION x RESOLUTION image.

    Args:
        boxes: sequence of normalized ``(x1, y1, x2, y2)`` boxes.
        is_sketch: draw every box in black instead of cycling through COLORS.

    Returns:
        A PIL RGB image, or ``None`` if ``boxes`` is empty.
    """
    if len(boxes) == 0:
        return None

    boxes = np.array(boxes) * RESOLUTION
    image = Image.new("RGB", (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
    drawing = ImageDraw.Draw(image)
    for i, box in enumerate(boxes.astype(int).tolist()):
        color = "black" if is_sketch else COLORS[i % len(COLORS)]
        drawing.rectangle(box, outline=color, width=4)

    return image
def clear(batch_size):
    """Reset the boxes state, sketch pad, layout preview and generated images.

    ``batch_size`` is unused; it is accepted only because the Clear button
    passes it as an input.
    """
    return [[], None, None, None]


def build_example_layout(prompt, *args):
    """Build the layout for an example prompt.

    Looks up the prompt's boxes in EXAMPLE_BOXES. The remaining example values
    in ``args`` are ignored.

    Returns:
        ``(boxes, sketchpad, layout_image)``:

        * ``sketchpad`` has no layers and shows the boxes in black as its
          composite.
        * ``layout_image`` is the colored preview.
    """
    boxes = EXAMPLE_BOXES[prompt]
    composite = draw_boxes(boxes, is_sketch=True)
    sketchpad = {"background": None, "layers": [], "composite": composite}
    layout_image = draw_boxes(boxes)
    return boxes, sketchpad, layout_image


def main():
    """Fetch the NLTK data and SDXL weights, then build and launch the Gradio demo."""
    nltk.download("averaged_perceptron_tagger")
    # NLTK >= 3.9 looks the English tagger up under this name instead;
    # downloading both keeps the demo working across NLTK versions.
    nltk.download("averaged_perceptron_tagger_eng")

    # Download the SDXL weights once and save them where `inference` loads them from.
    model = StableDiffusionXLPipeline.from_pretrained(REMOTE_MODEL_PATH)
    model.save_pretrained(LOCAL_MODEL_PATH)
    del model

    with gr.Blocks(
        css=CSS,
        title="Bounded Attention demo",
    ) as demo:
        gr.HTML(DESCRIPTION)
        gr.HTML(COPY_LINK)

        with gr.Column():
            gr.HTML("Scroll down to see examples of the required input format.")

            prompt = gr.Textbox(
                label="Text prompt",
            )

            subject_token_indices = gr.Textbox(
                label="The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
            )

            filter_token_indices = gr.Textbox(
                label="Optional: The token indices to filter, e.g. conjunctions, numbers, positional relations, etc. (if left empty, this will be automatically inferred)",
            )

            num_tokens = gr.Textbox(
                label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
            )

            with gr.Row():
                sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)")
                layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False)

            with gr.Row():
                clear_button = gr.Button(value="Clear")
                generate_layout_button = gr.Button(value="Generate layout")
                generate_image_button = gr.Button(value="Generate image")

            with gr.Row():
                out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)

            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    gr.HTML(ADVANCED_OPTION_DESCRIPTION)
                    batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
                    num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance")
                    init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=30, label="Initial step size")
                    final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=15, label="Final step size")
                    first_refinement_step = gr.Slider(minimum=0, maximum=50, step=1, value=15, label="The timestep from which to start refining the subject masks")
                    # A cluster count must be a positive integer, so no zero and no half steps.
                    num_clusters_per_subject = gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Number of clusters per subject")
                    cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
                    self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
                    num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations")
                    loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold")
                    classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale")
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")

        # Normalized boxes produced by `draw` / `build_example_layout`, consumed by `generate`.
        boxes = gr.State([])

        clear_button.click(
            clear,
            inputs=[batch_size],
            outputs=[boxes, sketchpad, layout_image, out_images],
            queue=False,
        )

        generate_layout_button.click(
            draw,
            inputs=[sketchpad],
            outputs=[boxes, layout_image],
            queue=False,
        )

        generate_image_button.click(
            fn=generate,
            inputs=[
                prompt, subject_token_indices, filter_token_indices, num_tokens,
                init_step_size, final_step_size, first_refinement_step, num_clusters_per_subject,
                cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
                batch_size, num_iterations, loss_threshold, num_guidance_steps,
                seed,
                boxes,
            ],
            outputs=[out_images],
            queue=True,
        )

        with gr.Column():
            # Each example row lines up with the `inputs` list below
            # (first_refinement_step is not part of the examples).
            gr.Examples(
                examples=[
                    [
                        PROMPT1, "2,3;6,7", "1,4,5,8,9", "10",
                        15, 10, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        12,
                    ],
                    [
                        PROMPT2, "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
                        25, 18, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        286,
                    ],
                    [
                        PROMPT3, "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
                        18, 12, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        216,
                    ],
                    [
                        PROMPT4, "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17",
                        25, 18, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        86,
                    ],
                    [
                        PROMPT5, "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
                        18, 12, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        152,
                    ],
                ],
                fn=build_example_layout,
                inputs=[
                    prompt, subject_token_indices, filter_token_indices, num_tokens,
                    init_step_size, final_step_size, num_clusters_per_subject,
                    cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
                    batch_size, num_iterations, loss_threshold, num_guidance_steps,
                    seed,
                ],
                outputs=[boxes, sketchpad, layout_image],
                run_on_click=True,
            )

        gr.HTML(FOOTNOTE)

    demo.launch(show_api=False, show_error=True)


if __name__ == "__main__":
    main()