omer11a committed
Commit f42e996 · 1 Parent(s): 49a7542

Fixed errors

Files changed (1)
  1. app.py +148 -145
app.py CHANGED
@@ -13,11 +13,51 @@ from pytorch_lightning import seed_everything
 
 from functools import partial
 
+MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
 RESOLUTION = 256
 MIN_SIZE = 0.01
 WHITE = 255
 COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
 
+CSS = """
+#paper-info a {
+    color:#008AD7;
+    text-decoration: none;
+}
+#paper-info a:hover {
+    cursor: pointer;
+    text-decoration: none;
+}
+
+.tooltip {
+    color: #555;
+    position: relative;
+    display: inline-block;
+    cursor: pointer;
+}
+
+.tooltip .tooltiptext {
+    visibility: hidden;
+    width: 400px;
+    background-color: #555;
+    color: #fff;
+    text-align: center;
+    padding: 5px;
+    border-radius: 5px;
+    position: absolute;
+    z-index: 1; /* Set z-index to 1 */
+    left: 10px;
+    top: 100%;
+    opacity: 0;
+    transition: opacity 0.3s;
+}
+
+.tooltip:hover .tooltiptext {
+    visibility: visible;
+    opacity: 1;
+    z-index: 9999; /* Set a high z-index value when hovering */
+}
+"""
 DESCRIPTION = """
 <p style="text-align: center; font-weight: bold;">
     <span style="font-size: 28px">Bounded Attention</span>
@@ -72,14 +112,8 @@ FOOTNOTE = """
 """
 
 
-MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
-scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
-model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16)
-model.unet.set_default_attn_processor()
-model.enable_sequential_cpu_offload()
-
-
 def inference(
+    model,
     boxes,
     prompts,
     subject_token_indices,
@@ -125,14 +159,15 @@ def inference(
     )
 
     register_attention_editor_diffusers(model, editor)
-
     images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
     unregister_attention_editor_diffusers(model)
    model.to(torch.device("cpu"))
+    return images
 
 
 @spaces.GPU(duration=300)
 def generate(
+    model,
     prompt,
     subject_token_indices,
     filter_token_indices,
@@ -162,7 +197,7 @@ def generate(
     prompts = [prompt.strip(".").strip(",").strip()] * batch_size
 
     images = inference(
-        boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
+        model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
         final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
         num_iterations, loss_threshold, num_guidance_steps, seed)
 
@@ -217,139 +252,107 @@ def clear(batch_size):
 
 
 def main():
-    css = """
-    #paper-info a {
-        color:#008AD7;
-        text-decoration: none;
-    }
-    #paper-info a:hover {
-        cursor: pointer;
-        text-decoration: none;
-    }
-
-    .tooltip {
-        color: #555;
-        position: relative;
-        display: inline-block;
-        cursor: pointer;
-    }
-
-    .tooltip .tooltiptext {
-        visibility: hidden;
-        width: 400px;
-        background-color: #555;
-        color: #fff;
-        text-align: center;
-        padding: 5px;
-        border-radius: 5px;
-        position: absolute;
-        z-index: 1; /* Set z-index to 1 */
-        left: 10px;
-        top: 100%;
-        opacity: 0;
-        transition: opacity 0.3s;
-    }
-
-    .tooltip:hover .tooltiptext {
-        visibility: visible;
-        opacity: 1;
-        z-index: 9999; /* Set a high z-index value when hovering */
-    }
-    """
-
-    nltk.download("averaged_perceptron_tagger")
-
-    with gr.Blocks(
-        css=css,
-        title="Bounded Attention demo",
-    ) as demo:
-        gr.HTML(DESCRIPTION)
-        gr.HTML(COPY_LINK)
-
-        with gr.Column():
-            gr.HTML("Scroll down to see examples of the required input format.")
-
-            prompt = gr.Textbox(
-                label="Text prompt",
-            )
-
-            subject_token_indices = gr.Textbox(
-                label="The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
-            )
-
-            filter_token_indices = gr.Textbox(
-                label="Optional: The token indices to filter, i.e. conjunctions, numbers, postional relations, etc. (if left empty, this will be automatically inferred)",
-            )
-
-            num_tokens = gr.Textbox(
-                label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
-            )
-
-            with gr.Row():
-                sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)")
-                layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False)
-
-            with gr.Row():
-                clear_button = gr.Button(value="Clear")
-                generate_layout_button = gr.Button(value="Generate layout")
-                generate_image_button = gr.Button(value="Generate image")
-
-            with gr.Row():
-                out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
-
-            with gr.Accordion("Advanced Options", open=False):
-                with gr.Column():
-                    gr.HTML(ADVANCED_OPTION_DESCRIPTION)
-                    batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
-                    num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance")
-                    init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=25, label="Initial step size")
-                    final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=10, label="Final step size")
-                    num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
-                    cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
-                    self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
-                    num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations")
-                    loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold")
-                    classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale")
-                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
-
-            boxes = gr.State([])
-
-            clear_button.click(
-                clear,
-                inputs=[batch_size],
-                outputs=[boxes, sketchpad, layout_image, out_images],
-                queue=False,
-            )
-
-            generate_layout_button.click(
-                draw,
-                inputs=[sketchpad],
-                outputs=[boxes, layout_image],
-                queue=False,
-            )
-
-            generate_image_button.click(
-                fn=generate,
-                inputs=[
-                    prompt, subject_token_indices, filter_token_indices, num_tokens,
-                    init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
-                    classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
-                    seed,
-                    boxes,
-                ],
-                outputs=[out_images],
-                queue=True,
-            )
-
-        with gr.Column():
-            gr.Examples(
-                examples=[
-                    ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
-                    ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
-                ],
-                inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
-            )
-
-        gr.HTML(FOOTNOTE)
-
-    demo.launch(show_api=False, show_error=True)
+    nltk.download("averaged_perceptron_tagger")
+    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+    model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16)
+    model.unet.set_default_attn_processor()
+    model.enable_sequential_cpu_offload()
+
+    with gr.Blocks(
+        css=CSS,
+        title="Bounded Attention demo",
+    ) as demo:
+        gr.HTML(DESCRIPTION)
+        gr.HTML(COPY_LINK)
+
+        with gr.Column():
+            gr.HTML("Scroll down to see examples of the required input format.")
+
+            prompt = gr.Textbox(
+                label="Text prompt",
+            )
+
+            subject_token_indices = gr.Textbox(
+                label="The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
+            )
+
+            filter_token_indices = gr.Textbox(
+                label="Optional: The token indices to filter, i.e. conjunctions, numbers, postional relations, etc. (if left empty, this will be automatically inferred)",
+            )
+
+            num_tokens = gr.Textbox(
+                label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
+            )
+
+            with gr.Row():
+                sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)")
+                layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False)
+
+            with gr.Row():
+                clear_button = gr.Button(value="Clear")
+                generate_layout_button = gr.Button(value="Generate layout")
+                generate_image_button = gr.Button(value="Generate image")
+
+            with gr.Row():
+                out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
+
+            with gr.Accordion("Advanced Options", open=False):
+                with gr.Column():
+                    gr.HTML(ADVANCED_OPTION_DESCRIPTION)
+                    batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
+                    num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance")
+                    init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=25, label="Initial step size")
+                    final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=10, label="Final step size")
+                    num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
+                    cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
+                    self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
+                    num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations")
+                    loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold")
+                    classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale")
+                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
+
+            boxes = gr.State([])
+
+            clear_button.click(
+                clear,
+                inputs=[batch_size],
+                outputs=[boxes, sketchpad, layout_image, out_images],
+                queue=False,
+            )
+
+            generate_layout_button.click(
+                draw,
+                inputs=[sketchpad],
+                outputs=[boxes, layout_image],
+                queue=False,
+            )
+
+            generate_image_button.click(
+                fn=partial(generate, model),
+                inputs=[
+                    prompt, subject_token_indices, filter_token_indices, num_tokens,
+                    init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
+                    classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
+                    seed,
+                    boxes,
+                ],
+                outputs=[out_images],
+                queue=True,
+            )
+
+        with gr.Column():
+            gr.Examples(
+                examples=[
+                    ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
+                    ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
+                ],
+                inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
+            )
+
+        gr.HTML(FOOTNOTE)
+
+    demo.launch(show_api=False, show_error=True)
+
+
+if __name__ == "__main__":
+    main()
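
The functional change in this commit is that the SDXL pipeline is no longer built at import time: main() constructs it and pre-binds it to the generate callback with functools.partial, so the @spaces.GPU handler receives the model as its first argument while the Gradio inputs list carries only UI values. A minimal, self-contained sketch of that binding pattern (FakeModel is a hypothetical stand-in for the real StableDiffusionXLPipeline, not part of the app):

    from functools import partial

    import gradio as gr


    class FakeModel:
        """Hypothetical stand-in for the pipeline that the real app loads in main()."""

        def __call__(self, prompt):
            return f"generated: {prompt}"


    def generate(model, prompt):
        # `model` arrives via partial(); only `prompt` comes from the Gradio inputs list.
        return model(prompt)


    def main():
        model = FakeModel()  # the real app builds the heavy pipeline here, once, at startup
        with gr.Blocks() as demo:
            prompt = gr.Textbox(label="Text prompt")
            output = gr.Textbox(label="Result")
            button = gr.Button("Generate")
            # Pre-bind the model so the click handler's signature matches the inputs.
            button.click(fn=partial(generate, model), inputs=[prompt], outputs=[output])
        demo.launch()


    if __name__ == "__main__":
        main()

Binding the model this way keeps the heavy object out of the event's input components and lets the same generate function be reused with any pipeline passed in by main().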