prithivMLmods committed
Commit fe264a3 · verified · 1 Parent(s): 19d58d4

Update app.py

Files changed (1)
  1. app.py +116 -116
app.py CHANGED
@@ -11,65 +11,67 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 from PIL import Image, ImageDraw
 import numpy as np
 
-# Load configuration and models
+# Load VAE and ControlNet (shared components)
+vae = AutoencoderKL.from_pretrained(
+    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+).to("cuda")
+
 config_file = hf_hub_download(
     "xinsir/controlnet-union-sdxl-1.0",
     filename="config_promax.json",
 )
-
 config = ControlNetModel_Union.load_config(config_file)
 controlnet_model = ControlNetModel_Union.from_config(config)
 model_file = hf_hub_download(
     "xinsir/controlnet-union-sdxl-1.0",
     filename="diffusion_pytorch_model_promax.safetensors",
 )
-
 sstate_dict = load_state_dict(model_file)
-model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
+controlnet, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
     controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
 )
-model.to(device="cuda", dtype=torch.float16)
+controlnet.to(device="cuda", dtype=torch.float16)
 
-vae = AutoencoderKL.from_pretrained(
-    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
-).to("cuda")
+# Define available models
+models = {
+    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
+    "RealVisXL V4.0 Lightning": "SG161222/RealVisXL_V4.0_Lightning",
+}
 
-# Initially load the default pipeline
+# Load default pipeline
+default_model = "RealVisXL V5.0 Lightning"
 pipe = StableDiffusionXLFillPipeline.from_pretrained(
-    "SG161222/RealVisXL_V5.0_Lightning",
+    models[default_model],
     torch_dtype=torch.float16,
     vae=vae,
-    controlnet=model,
+    controlnet=controlnet,
     variant="fp16",
 ).to("cuda")
-
 pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
 
-def load_model(selected_model):
-    global pipe
-    model_path = f"SG161222/{selected_model}"
-    pipe = StableDiffusionXLFillPipeline.from_pretrained(
-        model_path,
+# Function to load pipeline based on selected model
+def load_pipeline(model_name):
+    repo_id = models[model_name]
+    new_pipe = StableDiffusionXLFillPipeline.from_pretrained(
+        repo_id,
         torch_dtype=torch.float16,
         vae=vae,
-        controlnet=model,
+        controlnet=controlnet,
         variant="fp16",
     ).to("cuda")
-    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
-    return f"Loaded model: {selected_model}"
+    new_pipe.scheduler = TCDScheduler.from_config(new_pipe.scheduler.config)
+    return new_pipe
 
+# Prepare image and mask function (unchanged)
 def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
     target_size = (width, height)
 
-    # Calculate the scaling factor to fit the image within the target size
     scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
     new_width = int(image.width * scale_factor)
     new_height = int(image.height * scale_factor)
 
-    # Resize the source image to fit within target size
     source = image.resize((new_width, new_height), Image.LANCZOS)
 
-    # Apply resize option using percentages
     if resize_option == "Full":
         resize_percentage = 100
     elif resize_option == "50%":
@@ -81,27 +83,21 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
     else: # Custom
         resize_percentage = custom_resize_percentage
 
-    # Calculate new dimensions based on percentage
     resize_factor = resize_percentage / 100
     new_width = int(source.width * resize_factor)
     new_height = int(source.height * resize_factor)
 
-    # Ensure minimum size of 64 pixels
     new_width = max(new_width, 64)
     new_height = max(new_height, 64)
 
-    # Resize the image
     source = source.resize((new_width, new_height), Image.LANCZOS)
 
-    # Calculate the overlap in pixels based on the percentage
     overlap_x = int(new_width * (overlap_percentage / 100))
     overlap_y = int(new_height * (overlap_percentage / 100))
 
-    # Ensure minimum overlap of 1 pixel
     overlap_x = max(overlap_x, 1)
     overlap_y = max(overlap_y, 1)
 
-    # Calculate margins based on alignment
     if alignment == "Middle":
         margin_x = (target_size[0] - new_width) // 2
         margin_y = (target_size[1] - new_height) // 2
@@ -118,19 +114,15 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
         margin_x = (target_size[0] - new_width) // 2
         margin_y = target_size[1] - new_height
 
-    # Adjust margins to eliminate gaps
     margin_x = max(0, min(margin_x, target_size[0] - new_width))
     margin_y = max(0, min(margin_y, target_size[1] - new_height))
 
-    # Create a new background image and paste the resized source image
     background = Image.new('RGB', target_size, (255, 255, 255))
     background.paste(source, (margin_x, margin_y))
 
-    # Create the mask
     mask = Image.new('L', target_size, 255)
     mask_draw = ImageDraw.Draw(mask)
 
-    # Calculate overlap areas
     white_gaps_patch = 2
 
     left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
@@ -147,7 +139,6 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
     elif alignment == "Bottom":
         bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
 
-    # Draw the mask
     mask_draw.rectangle([
         (left_overlap, top_overlap),
         (right_overlap, bottom_overlap)
@@ -155,8 +146,9 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
 
     return background, mask
 
+# Updated inference function to use selected pipeline
 @spaces.GPU(duration=24)
-def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
+def infer(pipeline, image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
 
     cnet_image = background.copy()
@@ -169,10 +161,9 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
         negative_prompt_embeds,
         pooled_prompt_embeds,
         negative_pooled_prompt_embeds,
-    ) = pipe.encode_prompt(final_prompt, "cuda", True)
+    ) = pipeline.encode_prompt(final_prompt, "cuda", True)
 
-    # Generate the image
-    for image in pipe(
+    for image in pipeline(
         prompt_embeds=prompt_embeds,
         negative_prompt_embeds=negative_prompt_embeds,
         pooled_prompt_embeds=pooled_prompt_embeds,
@@ -180,32 +171,25 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
         image=cnet_image,
         num_inference_steps=num_inference_steps
     ):
-        pass # Wait for the generation to complete
-    generated_image = image # Get the last image
+        pass
+    generated_image = image
 
     generated_image = generated_image.convert("RGBA")
     cnet_image.paste(generated_image, (0, 0), mask)
 
     return cnet_image
 
+# Utility functions (unchanged)
 def clear_result():
-    """Clears the result Image."""
     return gr.update(value=None)
 
 def preload_presets(target_ratio, ui_width, ui_height):
-    """Updates the width and height sliders based on the selected aspect ratio."""
     if target_ratio == "9:16":
-        changed_width = 720
-        changed_height = 1280
-        return changed_width, changed_height, gr.update()
+        return 720, 1280, gr.update()
     elif target_ratio == "16:9":
-        changed_width = 1280
-        changed_height = 720
-        return changed_width, changed_height, gr.update()
+        return 1280, 720, gr.update()
     elif target_ratio == "1:1":
-        changed_width = 1024
-        changed_height = 1024
-        return changed_width, changed_height, gr.update()
+        return 1024, 1024, gr.update()
    elif target_ratio == "Custom":
         return ui_width, ui_height, gr.update(open=True)
 
@@ -223,13 +207,12 @@ def toggle_custom_resize_slider(resize_option):
     return gr.update(visible=(resize_option == "Custom"))
 
 def update_history(new_image, history):
-    """Updates the history gallery with the new image."""
     if history is None:
         history = []
     history.insert(0, new_image)
     return history
 
-# CSS and Title
+# CSS and title (unchanged)
 css = """
 h1 {
     text-align: center;
@@ -240,6 +223,7 @@ h1 {
 title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
 """
 
+# Gradio interface with model selection
 with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     with gr.Column():
         gr.HTML(title)
@@ -250,17 +234,20 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                     type="pil",
                     label="Input Image"
                 )
-                model_selection = gr.Dropdown(
-                    choices=["RealVisXL_V5.0_Lightning", "RealVisXL_V4.0_Lightning"],
-                    value="RealVisXL_V5.0_Lightning",
-                    label="Select Model"
-                )
+
                 with gr.Row():
                     with gr.Column(scale=2):
                         prompt_input = gr.Textbox(label="Prompt (Optional)")
                     with gr.Column(scale=1):
                         run_button = gr.Button("Generate")
 
+                with gr.Row():
+                    model_selector = gr.Dropdown(
+                        label="Select Model",
+                        choices=list(models.keys()),
+                        value="RealVisXL V5.0 Lightning",
+                    )
+
                 with gr.Row():
                     target_ratio = gr.Radio(
                         label="Expected Ratio",
@@ -291,6 +278,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                         step=8,
                         value=1280,
                     )
+
                     num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
                 with gr.Group():
                     overlap_percentage = gr.Slider(
@@ -320,7 +308,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                         value=50,
                         visible=False
                     )
-                status_text = gr.Textbox(label="Status", interactive=False)
+
                 gr.Examples(
                     examples=[
                         ["./examples/example_1.webp", 1280, 720, "Middle"],
@@ -339,62 +327,74 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                 )
         history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
 
-    # Event handlers
-    model_selection.change(fn=load_model, inputs=model_selection, outputs=status_text)
-    target_ratio.change(
-        fn=preload_presets,
-        inputs=[target_ratio, width_slider, height_slider],
-        outputs=[width_slider, height_slider, settings_panel],
-        queue=False
-    )
-    width_slider.change(
-        fn=select_the_right_preset,
-        inputs=[width_slider, height_slider],
-        outputs=[target_ratio],
-        queue=False
-    )
-    height_slider.change(
-        fn=select_the_right_preset,
-        inputs=[width_slider, height_slider],
-        outputs=[target_ratio],
-        queue=False
-    )
-    resize_option.change(
-        fn=toggle_custom_resize_slider,
-        inputs=[resize_option],
-        outputs=[custom_resize_percentage],
-        queue=False
-    )
-    run_button.click(
-        fn=clear_result,
-        inputs=None,
-        outputs=result,
-    ).then(
-        fn=infer,
-        inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
-                resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
-                overlap_left, overlap_right, overlap_top, overlap_bottom],
-        outputs=result,
-    ).then(
-        fn=lambda x, history: update_history(x, history),
-        inputs=[result, history_gallery],
-        outputs=history_gallery,
-    )
-    prompt_input.submit(
-        fn=clear_result,
-        inputs=None,
-        outputs=result,
-    ).then(
-        fn=infer,
-        inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
-                resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
-                overlap_left, overlap_right, overlap_top, overlap_bottom],
-        outputs=result,
-    ).then(
-        fn=lambda x, history: update_history(x, history),
-        inputs=[result, history_gallery],
-        outputs=history_gallery,
-    )
-    demo.load(fn=load_model, inputs=model_selection, outputs=status_text)
+    # State to hold the current pipeline
+    pipeline_state = gr.State(value=pipe)
+
+    # Update pipeline when model is selected
+    model_selector.change(
+        fn=load_pipeline,
+        inputs=model_selector,
+        outputs=pipeline_state,
+    )
+
+    target_ratio.change(
+        fn=preload_presets,
+        inputs=[target_ratio, width_slider, height_slider],
+        outputs=[width_slider, height_slider, settings_panel],
+        queue=False
+    )
+
+    width_slider.change(
+        fn=select_the_right_preset,
+        inputs=[width_slider, height_slider],
+        outputs=[target_ratio],
+        queue=False
+    )
+
+    height_slider.change(
+        fn=select_the_right_preset,
+        inputs=[width_slider, height_slider],
+        outputs=[target_ratio],
+        queue=False
+    )
+
+    resize_option.change(
+        fn=toggle_custom_resize_slider,
+        inputs=[resize_option],
+        outputs=[custom_resize_percentage],
+        queue=False
+    )
+
+    run_button.click(
+        fn=clear_result,
+        inputs=None,
+        outputs=result,
+    ).then(
+        fn=infer,
+        inputs=[pipeline_state, input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
+                resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
+                overlap_left, overlap_right, overlap_top, overlap_bottom],
+        outputs=result,
+    ).then(
+        fn=lambda x, history: update_history(x, history),
+        inputs=[result, history_gallery],
+        outputs=history_gallery,
+    )
+
+    prompt_input.submit(
+        fn=clear_result,
+        inputs=None,
+        outputs=result,
+    ).then(
+        fn=infer,
+        inputs=[pipeline_state, input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
+                resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
+                overlap_left, overlap_right, overlap_top, overlap_bottom],
+        outputs=result,
+    ).then(
+        fn=lambda x, history: update_history(x, history),
+        inputs=[result, history_gallery],
+        outputs=history_gallery,
+    )
 
 demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)
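A note on the pattern this commit adopts: instead of mutating a global `pipe` inside `load_model` and reporting through a status `gr.Textbox`, the dropdown now writes a freshly built pipeline into a `gr.State`, and `infer` receives that state as its first input. Below is a minimal, self-contained sketch of the same hand-off; it is illustrative only, with trivial callables standing in for the heavyweight SDXL pipelines (`BACKENDS`, `make_backend`, and `run` are hypothetical names, not part of app.py):

import gradio as gr

# Stand-ins for the expensive StableDiffusionXLFillPipeline objects.
BACKENDS = {"fast": str.upper, "plain": str.lower}

def make_backend(name):
    # Counterpart of load_pipeline(): rebuild the object when the dropdown changes.
    return BACKENDS[name]

def run(backend, text):
    # Counterpart of infer(): the gr.State value arrives as an ordinary argument.
    return backend(text)

with gr.Blocks() as demo:
    choice = gr.Dropdown(choices=list(BACKENDS), value="fast", label="Backend")
    state = gr.State(value=BACKENDS["fast"])  # holds the live object itself
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    choice.change(fn=make_backend, inputs=choice, outputs=state)
    inp.submit(fn=run, inputs=[state, inp], outputs=out)

demo.launch()

A practical difference from the `global pipe` approach: a `gr.State` value is tracked per user session, so one visitor switching models no longer swaps the pipeline out from under another.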
 
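For readers tracing `prepare_image_and_mask`: it reduces to pasting the (optionally resized) source onto a white canvas of the target size, then building an 'L'-mode mask that is white (255 = regenerate) everywhere except a black rectangle shrunk inward by the overlap, so the model also repaints a thin band just inside the source edge and blends the seam. A toy, PIL-only sketch of that geometry under simplified assumptions (middle alignment, uniform 10% overlap, no resize options; values are illustrative):

from PIL import Image, ImageDraw

target_w, target_h = 1024, 1024
source = Image.new("RGB", (512, 512), (120, 40, 40))  # stand-in input image
overlap = int(512 * 0.10)                             # 10% overlap band

# Center the source on a white target-size canvas (the "Middle" alignment case).
margin_x = (target_w - source.width) // 2
margin_y = (target_h - source.height) // 2
background = Image.new("RGB", (target_w, target_h), (255, 255, 255))
background.paste(source, (margin_x, margin_y))

# 255 = regenerate, 0 = keep. Shrinking the kept rectangle by `overlap`
# exposes a band inside the source edge for the model to repaint and blend.
mask = Image.new("L", (target_w, target_h), 255)
ImageDraw.Draw(mask).rectangle(
    [
        (margin_x + overlap, margin_y + overlap),
        (margin_x + source.width - overlap, margin_y + source.height - overlap),
    ],
    fill=0,
)

background.save("canvas.png")
mask.save("mask.png")

The same mask is reused at the end of `infer`: the generated frame is pasted back onto the canvas through it, so the kept region stays pixel-identical to the input.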