prithivMLmods committed on
Commit
d4884bc
·
verified ·
1 Parent(s): 1cbdebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -135
app.py CHANGED
@@ -11,44 +11,56 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
11
  from PIL import Image, ImageDraw
12
  import numpy as np
13
 
 
 
 
14
  config_file = hf_hub_download(
15
  "xinsir/controlnet-union-sdxl-1.0",
16
  filename="config_promax.json",
17
  )
18
-
19
  config = ControlNetModel_Union.load_config(config_file)
20
  controlnet_model = ControlNetModel_Union.from_config(config)
21
  model_file = hf_hub_download(
22
  "xinsir/controlnet-union-sdxl-1.0",
23
  filename="diffusion_pytorch_model_promax.safetensors",
24
  )
25
-
26
  sstate_dict = load_state_dict(model_file)
27
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
28
  controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
29
  )
30
  model.to(device="cuda", dtype=torch.float16)
31
 
 
32
  vae = AutoencoderKL.from_pretrained(
33
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
34
  ).to("cuda")
35
 
36
- # --- Define available pipelines ---
37
- model_mapping = {
38
- "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
39
- "RealVisXL V4.0 Lightning": "SG161222/RealVisXL_V4.0_Lightning",
40
- }
41
  pipelines = {}
42
- for name, repo in model_mapping.items():
43
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
44
- repo,
45
- torch_dtype=torch.float16,
46
- vae=vae,
47
- controlnet=model,
48
- variant="fp16",
49
- ).to("cuda")
50
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
51
- pipelines[name] = pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
54
  target_size = (width, height)
@@ -57,7 +69,7 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
57
  scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
58
  new_width = int(image.width * scale_factor)
59
  new_height = int(image.height * scale_factor)
60
-
61
  # Resize the source image to fit within target size
62
  source = image.resize((new_width, new_height), Image.LANCZOS)
63
 
@@ -109,6 +121,10 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
109
  elif alignment == "Bottom":
110
  margin_x = (target_size[0] - new_width) // 2
111
  margin_y = target_size[1] - new_height
 
 
 
 
112
 
113
  # Adjust margins to eliminate gaps
114
  margin_x = max(0, min(margin_x, target_size[0] - new_width))
@@ -119,66 +135,126 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
119
  background.paste(source, (margin_x, margin_y))
120
 
121
  # Create the mask
122
- mask = Image.new('L', target_size, 255)
123
  mask_draw = ImageDraw.Draw(mask)
124
 
125
- # Calculate overlap areas
126
- white_gaps_patch = 2
127
 
128
- left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
129
- right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
130
- top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
131
- bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
132
-
133
- if alignment == "Left":
134
- left_overlap = margin_x + overlap_x if overlap_left else margin_x
135
- elif alignment == "Right":
136
- right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
137
- elif alignment == "Top":
138
- top_overlap = margin_y + overlap_y if overlap_top else margin_y
139
- elif alignment == "Bottom":
140
- bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
141
 
142
- # Draw the mask
143
- mask_draw.rectangle([
144
- (left_overlap, top_overlap),
145
- (right_overlap, bottom_overlap)
146
- ], fill=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  return background, mask
149
 
 
150
  @spaces.GPU(duration=24)
151
- def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
152
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
153
-
154
- cnet_image = background.copy()
155
- cnet_image.paste(0, (0, 0), mask)
156
-
157
- final_prompt = f"{prompt_input} , high quality, 4k"
158
-
159
- (
160
- prompt_embeds,
161
- negative_prompt_embeds,
162
- pooled_prompt_embeds,
163
- negative_pooled_prompt_embeds,
164
- ) = pipe.encode_prompt(final_prompt, "cuda", True)
165
-
166
- # Generate the image
167
- for image in pipe(
168
- prompt_embeds=prompt_embeds,
169
- negative_prompt_embeds=negative_prompt_embeds,
170
- pooled_prompt_embeds=pooled_prompt_embeds,
171
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
172
- image=cnet_image,
173
- num_inference_steps=num_inference_steps
174
- ):
175
- pass # Wait for the generation to complete
176
- generated_image = image # Get the last image
177
-
178
- generated_image = generated_image.convert("RGBA")
179
- cnet_image.paste(generated_image, (0, 0), mask)
180
-
181
- return cnet_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  def clear_result():
184
  """Clears the result Image."""
@@ -189,17 +265,21 @@ def preload_presets(target_ratio, ui_width, ui_height):
189
  if target_ratio == "9:16":
190
  changed_width = 720
191
  changed_height = 1280
192
- return changed_width, changed_height, gr.update()
193
  elif target_ratio == "16:9":
194
  changed_width = 1280
195
  changed_height = 720
196
- return changed_width, changed_height, gr.update()
197
  elif target_ratio == "1:1":
198
  changed_width = 1024
199
  changed_height = 1024
200
- return changed_width, changed_height, gr.update()
201
  elif target_ratio == "Custom":
 
202
  return ui_width, ui_height, gr.update(open=True)
 
 
 
203
 
204
  def select_the_right_preset(user_width, user_height):
205
  if user_width == 720 and user_height == 1280:
@@ -216,59 +296,71 @@ def toggle_custom_resize_slider(resize_option):
216
 
217
  def update_history(new_image, history):
218
  """Updates the history gallery with the new image."""
 
 
219
  if history is None:
220
  history = []
 
221
  history.insert(0, new_image)
 
 
 
 
222
  return history
223
 
224
- # --- CSS and Title (unchanged) ---
225
  css = """
226
  h1 {
227
- text-align: center;
228
- display: block;
 
 
 
 
229
  }
230
  """
231
 
232
-
233
  title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
 
234
  """
235
 
 
236
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
237
  with gr.Column():
238
  gr.HTML(title)
239
 
240
  with gr.Row():
241
- with gr.Column():
242
  input_image = gr.Image(
243
  type="pil",
244
  label="Input Image"
245
  )
246
 
247
- with gr.Row():
248
- model_selector = gr.Dropdown(
249
  label="Select Model",
250
  choices=list(pipelines.keys()),
251
- value="RealVisXL V5.0 Lightning",
252
  )
253
-
254
  with gr.Row():
255
  with gr.Column(scale=2):
256
- prompt_input = gr.Textbox(label="Prompt (Optional)")
257
- with gr.Column(scale=1):
258
- run_button = gr.Button("Outpaint Image")
259
 
260
  with gr.Row():
261
  target_ratio = gr.Radio(
262
- label="Expected Ratio",
263
  choices=["9:16", "16:9", "1:1", "Custom"],
264
- value="9:16",
265
  scale=2
266
  )
267
-
268
  alignment_dropdown = gr.Dropdown(
269
  choices=["Middle", "Left", "Right", "Top", "Bottom"],
270
  value="Middle",
271
- label="Alignment"
272
  )
273
 
274
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
@@ -276,39 +368,43 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
276
  with gr.Row():
277
  width_slider = gr.Slider(
278
  label="Target Width",
279
- minimum=720,
280
  maximum=1536,
281
- step=8,
282
- value=720,
283
  )
284
  height_slider = gr.Slider(
285
  label="Target Height",
286
- minimum=720,
287
  maximum=1536,
288
- step=8,
289
- value=1280,
290
  )
291
-
292
  num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
 
293
  with gr.Group():
294
  overlap_percentage = gr.Slider(
295
  label="Mask overlap (%)",
 
296
  minimum=1,
297
  maximum=50,
298
- value=10,
299
  step=1
300
  )
 
301
  with gr.Row():
302
- overlap_top = gr.Checkbox(label="Overlap Top", value=True)
303
- overlap_right = gr.Checkbox(label="Overlap Right", value=True)
304
- with gr.Row():
305
- overlap_left = gr.Checkbox(label="Overlap Left", value=True)
306
- overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
307
  with gr.Row():
308
  resize_option = gr.Radio(
309
- label="Resize input image",
 
310
  choices=["Full", "50%", "33%", "25%", "Custom"],
311
- value="Full"
312
  )
313
  custom_resize_percentage = gr.Slider(
314
  label="Custom resize (%)",
@@ -316,27 +412,40 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
316
  maximum=100,
317
  step=1,
318
  value=50,
319
- visible=False
320
  )
321
-
322
  gr.Examples(
323
  examples=[
324
- ["./examples/example_1.webp", 1280, 720, "Middle"],
325
- ["./examples/example_2.jpg", 1440, 810, "Left"],
326
- ["./examples/example_3.jpg", 1024, 1024, "Top"],
327
- ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
328
  ],
329
- inputs=[input_image, width_slider, height_slider, alignment_dropdown],
 
330
  )
331
 
332
- with gr.Column():
333
  result = gr.Image(
334
  interactive=False,
335
  label="Generated Image",
336
  format="png",
337
  )
338
- history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
 
 
 
 
 
 
 
 
339
 
 
 
 
 
340
  target_ratio.change(
341
  fn=preload_presets,
342
  inputs=[target_ratio, width_slider, height_slider],
@@ -344,13 +453,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
344
  queue=False
345
  )
346
 
 
347
  width_slider.change(
348
  fn=select_the_right_preset,
349
  inputs=[width_slider, height_slider],
350
  outputs=[target_ratio],
351
  queue=False
352
  )
353
-
354
  height_slider.change(
355
  fn=select_the_right_preset,
356
  inputs=[width_slider, height_slider],
@@ -358,47 +467,54 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
358
  queue=False
359
  )
360
 
 
361
  resize_option.change(
362
  fn=toggle_custom_resize_slider,
363
  inputs=[resize_option],
364
  outputs=[custom_resize_percentage],
365
  queue=False
366
  )
367
-
 
 
 
 
 
 
 
 
368
  run_button.click(
369
  fn=clear_result,
370
  inputs=None,
371
- outputs=result,
 
372
  ).then(
373
  fn=infer,
374
- inputs=[input_image, width_slider, height_slider, overlap_percentage,
375
- num_inference_steps, resize_option, custom_resize_percentage,
376
- prompt_input, alignment_dropdown,
377
- overlap_left, overlap_right, overlap_top, overlap_bottom,
378
- model_selector],
379
- outputs=result,
380
  ).then(
381
- fn=lambda x, history: update_history(x, history),
382
- inputs=[result, history_gallery],
383
- outputs=history_gallery,
384
  )
385
 
 
386
  prompt_input.submit(
387
- fn=clear_result,
388
  inputs=None,
389
- outputs=result,
 
390
  ).then(
391
  fn=infer,
392
- inputs=[input_image, width_slider, height_slider, overlap_percentage,
393
- num_inference_steps, resize_option, custom_resize_percentage,
394
- prompt_input, alignment_dropdown,
395
- overlap_left, overlap_right, overlap_top, overlap_bottom,
396
- model_selector],
397
- outputs=result,
398
  ).then(
399
- fn=lambda x, history: update_history(x, history),
400
  inputs=[result, history_gallery],
401
- outputs=history_gallery,
402
  )
403
 
 
 
 
404
  demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)
 
11
  from PIL import Image, ImageDraw
12
  import numpy as np
13
 
14
+ # --- Configuration and Model Loading ---
15
+
16
+ # Load ControlNet Union
17
  config_file = hf_hub_download(
18
  "xinsir/controlnet-union-sdxl-1.0",
19
  filename="config_promax.json",
20
  )
 
21
  config = ControlNetModel_Union.load_config(config_file)
22
  controlnet_model = ControlNetModel_Union.from_config(config)
23
  model_file = hf_hub_download(
24
  "xinsir/controlnet-union-sdxl-1.0",
25
  filename="diffusion_pytorch_model_promax.safetensors",
26
  )
 
27
  sstate_dict = load_state_dict(model_file)
28
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
29
  controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
30
  )
31
  model.to(device="cuda", dtype=torch.float16)
32
 
33
+ # Load VAE
34
  vae = AutoencoderKL.from_pretrained(
35
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
36
  ).to("cuda")
37
 
38
+ # --- Load Multiple Pipelines ---
 
 
 
 
39
  pipelines = {}
40
+
41
+ # Load RealVisXL V5.0 Lightning
42
+ pipe_v5 = StableDiffusionXLFillPipeline.from_pretrained(
43
+ "SG161222/RealVisXL_V5.0_Lightning",
44
+ torch_dtype=torch.float16,
45
+ vae=vae,
46
+ controlnet=model, # Use the same controlnet
47
+ variant="fp16",
48
+ ).to("cuda")
49
+ pipe_v5.scheduler = TCDScheduler.from_config(pipe_v5.scheduler.config)
50
+ pipelines["RealVisXL V5.0 Lightning"] = pipe_v5
51
+
52
+ # Load RealVisXL V4.0 Lightning
53
+ pipe_v4 = StableDiffusionXLFillPipeline.from_pretrained(
54
+ "SG161222/RealVisXL_V4.0_Lightning",
55
+ torch_dtype=torch.float16,
56
+ vae=vae, # Use the same VAE
57
+ controlnet=model, # Use the same controlnet
58
+ variant="fp16",
59
+ ).to("cuda")
60
+ pipe_v4.scheduler = TCDScheduler.from_config(pipe_v4.scheduler.config)
61
+ pipelines["RealVisXL V4.0 Lightning"] = pipe_v4
62
+
63
+ # --- Helper Functions ---
64
 
65
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
66
  target_size = (width, height)
 
69
  scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
70
  new_width = int(image.width * scale_factor)
71
  new_height = int(image.height * scale_factor)
72
+
73
  # Resize the source image to fit within target size
74
  source = image.resize((new_width, new_height), Image.LANCZOS)
75
 
 
121
  elif alignment == "Bottom":
122
  margin_x = (target_size[0] - new_width) // 2
123
  margin_y = target_size[1] - new_height
124
+ else: # Default to Middle if alignment is somehow invalid
125
+ margin_x = (target_size[0] - new_width) // 2
126
+ margin_y = (target_size[1] - new_height) // 2
127
+
128
 
129
  # Adjust margins to eliminate gaps
130
  margin_x = max(0, min(margin_x, target_size[0] - new_width))
 
135
  background.paste(source, (margin_x, margin_y))
136
 
137
  # Create the mask
138
+ mask = Image.new('L', target_size, 255) # White background (area to be filled)
139
  mask_draw = ImageDraw.Draw(mask)
140
 
141
+ # Calculate overlap areas (where the mask should be black = keep original)
142
+ white_gaps_patch = 2 # Small value to ensure no tiny gaps at edges if overlap is off
143
 
144
+ # Determine the coordinates for the black rectangle (the non-masked area)
145
+ # Start with the full area covered by the pasted image
146
+ left_black = margin_x
147
+ top_black = margin_y
148
+ right_black = margin_x + new_width
149
+ bottom_black = margin_y + new_height
150
+
151
+ # Adjust the black area based on overlap checkboxes
152
+ if overlap_left:
153
+ left_black += overlap_x
154
+ else:
155
+ # If not overlapping left, ensure the black mask starts exactly at the image edge or slightly inside
156
+ left_black += white_gaps_patch if alignment != "Left" else 0
157
 
158
+ if overlap_right:
159
+ right_black -= overlap_x
160
+ else:
161
+ # If not overlapping right, ensure the black mask ends exactly at the image edge or slightly inside
162
+ right_black -= white_gaps_patch if alignment != "Right" else 0
163
+
164
+ if overlap_top:
165
+ top_black += overlap_y
166
+ else:
167
+ # If not overlapping top, ensure the black mask starts exactly at the image edge or slightly inside
168
+ top_black += white_gaps_patch if alignment != "Top" else 0
169
+
170
+ if overlap_bottom:
171
+ bottom_black -= overlap_y
172
+ else:
173
+ # If not overlapping bottom, ensure the black mask ends exactly at the image edge or slightly inside
174
+ bottom_black -= white_gaps_patch if alignment != "Bottom" else 0
175
+
176
+ # Ensure coordinates are valid (left < right, top < bottom)
177
+ left_black = min(left_black, target_size[0])
178
+ top_black = min(top_black, target_size[1])
179
+ right_black = max(left_black, right_black) # Ensure right >= left
180
+ bottom_black = max(top_black, bottom_black) # Ensure bottom >= top
181
+ right_black = min(right_black, target_size[0])
182
+ bottom_black = min(bottom_black, target_size[1])
183
+
184
+
185
+ # Draw the black rectangle onto the white mask
186
+ # The area *inside* this rectangle will be kept (mask value 0)
187
+ # The area *outside* this rectangle will be filled (mask value 255)
188
+ if right_black > left_black and bottom_black > top_black:
189
+ mask_draw.rectangle(
190
+ [(left_black, top_black), (right_black, bottom_black)],
191
+ fill=0 # Black means keep this area
192
+ )
193
 
194
  return background, mask
195
 
196
+
197
  @spaces.GPU(duration=24)
198
+ def infer(selected_model_name, image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
199
+ if image is None:
200
+ raise gr.Error("Please upload an input image.")
201
+ try:
202
+ # Select the pipeline based on the dropdown choice
203
+ pipe = pipelines[selected_model_name]
204
+
205
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
206
+
207
+ # Create the controlnet input image (original image pasted on white bg, with masked area blacked out)
208
+ cnet_image = background.copy()
209
+ # Create a black image of the same size as the mask
210
+ black_fill = Image.new('RGB', mask.size, (0, 0, 0))
211
+ # Paste the black fill using the mask (where mask is 255/white, paste black)
212
+ cnet_image.paste(black_fill, (0, 0), mask)
213
+
214
+
215
+ final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
216
+
217
+ (
218
+ prompt_embeds,
219
+ negative_prompt_embeds,
220
+ pooled_prompt_embeds,
221
+ negative_pooled_prompt_embeds,
222
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
223
+
224
+ # Generate the image
225
+ generator = torch.Generator(device="cuda").manual_seed(np.random.randint(0, 2**32)) # Add random seed
226
+
227
+ # The pipeline expects the 'image' argument to be the background with the original content
228
+ # and the 'mask_image' argument to define the area to *inpaint* (white area in our mask)
229
+ result_image = pipe(
230
+ prompt_embeds=prompt_embeds,
231
+ negative_prompt_embeds=negative_prompt_embeds,
232
+ pooled_prompt_embeds=pooled_prompt_embeds,
233
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
234
+ image=background, # The background containing the original image
235
+ mask_image=mask, # The mask (white = fill, black = keep)
236
+ control_image=cnet_image, # ControlNet input image
237
+ num_inference_steps=num_inference_steps,
238
+ generator=generator, # Use generator for reproducibility if needed
239
+ output_type="pil" # Ensure PIL output
240
+ ).images[0]
241
+
242
+ # The pipeline directly returns the final composited image.
243
+ # No need for manual pasting like before.
244
+
245
+ return result_image
246
+ except Exception as e:
247
+ print(f"Error during inference: {e}")
248
+ import traceback
249
+ traceback.print_exc()
250
+ # Return the background image or raise a Gradio error for clarity
251
+ # raise gr.Error(f"Inference failed: {e}")
252
+ # Or return the prepared background/mask for debugging
253
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
254
+ # Combine background and mask for visualization
255
+ debug_img = Image.blend(background.convert("RGBA"), mask.convert("RGBA"), 0.5)
256
+ return debug_img # Return a debug image or None
257
+
258
 
259
  def clear_result():
260
  """Clears the result Image."""
 
265
  if target_ratio == "9:16":
266
  changed_width = 720
267
  changed_height = 1280
268
+ return changed_width, changed_height, gr.update(open=False) # Close accordion on preset
269
  elif target_ratio == "16:9":
270
  changed_width = 1280
271
  changed_height = 720
272
+ return changed_width, changed_height, gr.update(open=False) # Close accordion on preset
273
  elif target_ratio == "1:1":
274
  changed_width = 1024
275
  changed_height = 1024
276
+ return changed_width, changed_height, gr.update(open=False) # Close accordion on preset
277
  elif target_ratio == "Custom":
278
+ # When switching to Custom, keep current slider values but open accordion
279
  return ui_width, ui_height, gr.update(open=True)
280
+ # Should not happen, but return current values if it does
281
+ return ui_width, ui_height, gr.update()
282
+
283
 
284
  def select_the_right_preset(user_width, user_height):
285
  if user_width == 720 and user_height == 1280:
 
296
 
297
  def update_history(new_image, history):
298
  """Updates the history gallery with the new image."""
299
+ if new_image is None: # Don't add None to history (e.g., on clear or error)
300
+ return history
301
  if history is None:
302
  history = []
303
+ # Prepend the new image (as PIL or path depending on Gallery config)
304
  history.insert(0, new_image)
305
+ # Limit history size if desired (e.g., keep last 12)
306
+ max_history = 12
307
+ if len(history) > max_history:
308
+ history = history[:max_history]
309
  return history
310
 
311
+ # --- CSS and Title ---
312
  css = """
313
  h1 {
314
+ text-align: center;
315
+ display: block;
316
+ }
317
+ .gradio-container {
318
+ max-width: 1280px !important;
319
+ margin: auto !important;
320
  }
321
  """
322
 
 
323
  title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
324
+ <p align="center">Expand images using ControlNet Union and Lightning models. Choose a base model below.</p>
325
  """
326
 
327
+ # --- Gradio UI ---
328
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
329
  with gr.Column():
330
  gr.HTML(title)
331
 
332
  with gr.Row():
333
+ with gr.Column(scale=2): # Input column
334
  input_image = gr.Image(
335
  type="pil",
336
  label="Input Image"
337
  )
338
 
339
+ # --- Model Selector ---
340
+ model_selector = gr.Dropdown(
341
  label="Select Model",
342
  choices=list(pipelines.keys()),
343
+ value="RealVisXL V5.0 Lightning", # Default model
344
  )
345
+
346
  with gr.Row():
347
  with gr.Column(scale=2):
348
+ prompt_input = gr.Textbox(label="Prompt (Describe the desired output)", placeholder="e.g., beautiful landscape, photorealistic")
349
+ with gr.Column(scale=1, min_width=120):
350
+ run_button = gr.Button("Generate", variant="primary")
351
 
352
  with gr.Row():
353
  target_ratio = gr.Radio(
354
+ label="Target Ratio",
355
  choices=["9:16", "16:9", "1:1", "Custom"],
356
+ value="9:16", # Default ratio
357
  scale=2
358
  )
359
+
360
  alignment_dropdown = gr.Dropdown(
361
  choices=["Middle", "Left", "Right", "Top", "Bottom"],
362
  value="Middle",
363
+ label="Align Input Image"
364
  )
365
 
366
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
 
368
  with gr.Row():
369
  width_slider = gr.Slider(
370
  label="Target Width",
371
+ minimum=512, # Lowered minimum slightly
372
  maximum=1536,
373
+ step=64, # Steps of 64 common for SDXL
374
+ value=720, # Default width
375
  )
376
  height_slider = gr.Slider(
377
  label="Target Height",
378
+ minimum=512, # Lowered minimum slightly
379
  maximum=1536,
380
+ step=64, # Steps of 64
381
+ value=1280, # Default height
382
  )
383
+
384
  num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
385
+
386
  with gr.Group():
387
  overlap_percentage = gr.Slider(
388
  label="Mask overlap (%)",
389
+ info="Percentage of the input image edge to keep (reduces seams)",
390
  minimum=1,
391
  maximum=50,
392
+ value=10, # Default overlap
393
  step=1
394
  )
395
+ gr.Markdown("Select edges to apply overlap:")
396
  with gr.Row():
397
+ overlap_top = gr.Checkbox(label="Top", value=True)
398
+ overlap_right = gr.Checkbox(label="Right", value=True)
399
+ overlap_left = gr.Checkbox(label="Left", value=True)
400
+ overlap_bottom = gr.Checkbox(label="Bottom", value=True)
401
+
402
  with gr.Row():
403
  resize_option = gr.Radio(
404
+ label="Resize input image before placing",
405
+ info="Scale the input image relative to its fitted size",
406
  choices=["Full", "50%", "33%", "25%", "Custom"],
407
+ value="Full" # Default resize option
408
  )
409
  custom_resize_percentage = gr.Slider(
410
  label="Custom resize (%)",
 
412
  maximum=100,
413
  step=1,
414
  value=50,
415
+ visible=False # Initially hidden
416
  )
417
+
418
  gr.Examples(
419
  examples=[
420
+ ["./examples/example_1.webp", "RealVisXL V5.0 Lightning", 1280, 720, "Middle"],
421
+ ["./examples/example_2.jpg", "RealVisXL V4.0 Lightning", 1440, 810, "Left"],
422
+ ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024, "Top"],
423
+ ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024, "Bottom"],
424
  ],
425
+ inputs=[input_image, model_selector, width_slider, height_slider, alignment_dropdown],
426
+ label="Examples (Prompt is optional)"
427
  )
428
 
429
+ with gr.Column(scale=3): # Output column
430
  result = gr.Image(
431
  interactive=False,
432
  label="Generated Image",
433
  format="png",
434
  )
435
+ history_gallery = gr.Gallery(
436
+ label="History",
437
+ columns=4, # Adjust columns as needed
438
+ object_fit="contain",
439
+ interactive=False,
440
+ show_label=True,
441
+ allow_preview=True,
442
+ preview=True
443
+ )
444
 
445
+
446
+ # --- Event Listeners ---
447
+
448
+ # Update sliders and accordion based on ratio selection
449
  target_ratio.change(
450
  fn=preload_presets,
451
  inputs=[target_ratio, width_slider, height_slider],
 
453
  queue=False
454
  )
455
 
456
+ # Update ratio selection based on slider changes
457
  width_slider.change(
458
  fn=select_the_right_preset,
459
  inputs=[width_slider, height_slider],
460
  outputs=[target_ratio],
461
  queue=False
462
  )
 
463
  height_slider.change(
464
  fn=select_the_right_preset,
465
  inputs=[width_slider, height_slider],
 
467
  queue=False
468
  )
469
 
470
+ # Show/hide custom resize slider
471
  resize_option.change(
472
  fn=toggle_custom_resize_slider,
473
  inputs=[resize_option],
474
  outputs=[custom_resize_percentage],
475
  queue=False
476
  )
477
+
478
+ # Define inputs for the main inference function
479
+ infer_inputs = [
480
+ model_selector, input_image, width_slider, height_slider, overlap_percentage,
481
+ num_inference_steps, resize_option, custom_resize_percentage, prompt_input,
482
+ alignment_dropdown, overlap_left, overlap_right, overlap_top, overlap_bottom
483
+ ]
484
+
485
+ # --- Run Button Click ---
486
  run_button.click(
487
  fn=clear_result,
488
  inputs=None,
489
+ outputs=[result], # Clear only the main result image
490
+ queue=False # Clearing should be fast
491
  ).then(
492
  fn=infer,
493
+ inputs=infer_inputs,
494
+ outputs=[result], # Output to the main result image
 
 
 
 
495
  ).then(
496
+ fn=update_history, # Use the specific update function
497
+ inputs=[result, history_gallery], # Pass the result and current history
498
+ outputs=[history_gallery], # Update the history gallery
499
  )
500
 
501
+ # --- Prompt Submit (Enter Key) ---
502
  prompt_input.submit(
503
+ fn=clear_result,
504
  inputs=None,
505
+ outputs=[result],
506
+ queue=False
507
  ).then(
508
  fn=infer,
509
+ inputs=infer_inputs,
510
+ outputs=[result],
 
 
 
 
511
  ).then(
512
+ fn=update_history,
513
  inputs=[result, history_gallery],
514
+ outputs=[history_gallery],
515
  )
516
 
517
+ # --- Launch App ---
518
+ # Make sure you have example images at the specified paths or remove/update the gr.Examples section
519
+ # Create an 'examples' directory and place images like 'example_1.webp', 'example_2.jpg', 'example_3.jpg' inside it.
520
  demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)