prithivMLmods commited on
Commit
04cce22
·
verified ·
1 Parent(s): 11d7c13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -123
app.py CHANGED
@@ -2,172 +2,258 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  from diffusers import AutoencoderKL, TCDScheduler
5
- # (Assume ControlNet manual load or from_pretrained is already working)
6
- from controlnet_union import ControlNetModel_Union
7
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
8
  from gradio_imageslider import ImageSlider
9
  from huggingface_hub import hf_hub_download
10
 
 
 
 
11
  from PIL import Image, ImageDraw
12
  import numpy as np
13
 
14
- # --- Load ControlNet and SDXL Fill Pipeline ---
15
- # (Either manual download or via from_pretrained)
16
- controlnet_model = ControlNetModel_Union.from_pretrained(
17
  "xinsir/controlnet-union-sdxl-1.0",
18
- torch_dtype=torch.float16,
19
- variant="fp16"
20
- ).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  vae = AutoencoderKL.from_pretrained(
23
- "madebyollin/sdxl-vae-fp16-fix",
24
- torch_dtype=torch.float16
25
  ).to("cuda")
26
 
27
  pipe = StableDiffusionXLFillPipeline.from_pretrained(
28
  "SG161222/RealVisXL_V5.0_Lightning",
29
  torch_dtype=torch.float16,
30
  vae=vae,
31
- controlnet=controlnet_model,
32
  variant="fp16",
33
  ).to("cuda")
 
34
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
35
 
36
- # --- Utility functions ---
37
  def can_expand(source_width, source_height, target_width, target_height, alignment):
 
38
  if alignment in ("Left", "Right") and source_width >= target_width:
39
  return False
40
  if alignment in ("Top", "Bottom") and source_height >= target_height:
41
  return False
42
  return True
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- def prepare_image_and_mask(image, width, height, overlap_percentage,
46
- resize_option, custom_resize_percentage,
47
- alignment, overlap_left, overlap_right,
48
- overlap_top, overlap_bottom):
49
- target = (width, height)
50
- scale = min(target[0] / image.width, target[1] / image.height)
51
- w, h = int(image.width * scale), int(image.height * scale)
52
- src = image.resize((w, h), Image.LANCZOS)
53
-
54
- # Resize percentage
55
- if resize_option == "Full": pct = 100
56
- elif resize_option == "50%": pct = 50
57
- elif resize_option == "33%": pct = 33
58
- elif resize_option == "25%": pct = 25
59
- else: pct = custom_resize_percentage
60
-
61
- rw, rh = max(int(src.width * pct / 100), 64), max(int(src.height * pct / 100), 64)
62
- src = src.resize((rw, rh), Image.LANCZOS)
63
-
64
- ox = max(int(rw * overlap_percentage / 100), 1)
65
- oy = max(int(rh * overlap_percentage / 100), 1)
66
-
67
- # Margins
68
- if alignment == "Middle": mx, my = (width - rw)//2, (height - rh)//2
69
- elif alignment == "Left": mx, my = 0, (height - rh)//2
70
- elif alignment == "Right": mx, my = width - rw, (height - rh)//2
71
- elif alignment == "Top": mx, my = (width - rw)//2, 0
72
- else: mx, my = (width - rw)//2, height - rh
73
-
74
- mx, my = max(0, min(mx, width - rw)), max(0, min(my, height - rh))
75
-
76
- bg = Image.new("RGB", target, (255,255,255))
77
- bg.paste(src, (mx, my))
78
-
79
- mask = Image.new("L", target, 255)
80
- d = ImageDraw.Draw(mask)
81
-
82
- lx = mx + (ox if overlap_left else 2)
83
- rx = mx + rw - (ox if overlap_right else 2)
84
- ty = my + (oy if overlap_top else 2)
85
- by = my + rh - (oy if overlap_bottom else 2)
86
-
87
- # Edge adjustments
88
- if alignment == "Left": lx = mx + (ox if overlap_left else 0)
89
- if alignment == "Right": rx = mx + rw - (ox if overlap_right else 0)
90
- if alignment == "Top": ty = my + (oy if overlap_top else 0)
91
- if alignment == "Bottom": by = my + rh - (oy if overlap_bottom else 0)
92
-
93
- d.rectangle([(lx, ty), (rx, by)], fill=0)
94
- return bg, mask
95
-
96
-
97
- def preview_image_and_mask(*args):
98
- bg, mask = prepare_image_and_mask(*args)
99
- vis = bg.copy().convert("RGBA")
100
- red = Image.new("RGBA", bg.size, (255,0,0,64))
101
- overlay = Image.new("RGBA", bg.size, (0,0,0,0))
102
- overlay.paste(red, (0,0), mask)
103
- return Image.alpha_composite(vis, overlay)
104
-
105
- # --- Fixed infer: return list for slider ---
106
  @spaces.GPU(duration=24)
107
- def infer(image, width, height, overlap_percentage, num_inference_steps,
108
- resize_option, custom_resize_percentage, prompt_input,
109
- alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
110
- background, mask = prepare_image_and_mask(
111
- image, width, height, overlap_percentage,
112
- resize_option, custom_resize_percentage,
113
- alignment, overlap_left, overlap_right,
114
- overlap_top, overlap_bottom
115
- )
116
  if not can_expand(background.width, background.height, width, height, alignment):
117
  alignment = "Middle"
118
 
119
- hole = background.copy()
120
- hole.paste(0, (0,0), mask)
121
 
122
  final_prompt = f"{prompt_input} , high quality, 4k"
123
- embeds = pipe.encode_prompt(final_prompt, "cuda", True)
124
-
125
- # Run pipeline and grab last frame
126
- gen = pipe(
127
- prompt_embeds=embeds[0],
128
- negative_prompt_embeds=embeds[1],
129
- pooled_prompt_embeds=embeds[2],
130
- negative_pooled_prompt_embeds=embeds[3],
131
- image=hole,
132
- num_inference_steps=num_inference_steps
133
- )
134
- last = None
135
- for img in gen:
136
- last = img
137
 
138
- out = last.convert("RGBA")
139
- hole.paste(out, (0,0), mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- # Return a list: [input_hole_image, final_output]
142
- return [background, hole]
143
 
 
144
 
145
  def clear_result():
 
146
  return gr.update(value=None)
147
 
148
- def preload_presets(ratio, w, h):
149
- if ratio == "9:16": return 720, 1280, gr.update()
150
- if ratio == "16:9": return 1280, 720, gr.update()
151
- if ratio == "1:1": return 1024, 1024, gr.update()
152
- return w, h, gr.update(open=True)
153
-
154
- def select_the_right_preset(w, h):
155
- if (w,h) == (720,1280): return "9:16"
156
- if (w,h) == (1280,720): return "16:9"
157
- if (w,h) == (1024,1024): return "1:1"
158
- return "Custom"
159
-
160
- def toggle_custom_resize_slider(opt):
161
- return gr.update(visible=(opt=="Custom"))
162
-
163
- def update_history(img, history):
164
- history = history or []
165
- history.insert(0, img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  return history
167
 
168
- css = ".gradio-container { width: 1200px !important; }"
169
- title = "<h1 align='center'>Diffusers Image Outpaint Lightning</h1>"
 
 
 
 
 
170
 
 
 
171
  with gr.Blocks(css=css) as demo:
172
  gr.HTML(title)
173
  with gr.Row():
 
2
  import spaces
3
  import torch
4
  from diffusers import AutoencoderKL, TCDScheduler
5
+ from diffusers.models.model_loading_utils import load_state_dict
 
 
6
  from gradio_imageslider import ImageSlider
7
  from huggingface_hub import hf_hub_download
8
 
9
+ from controlnet_union import ControlNetModel_Union
10
+ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
11
+
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
15
+ config_file = hf_hub_download(
 
 
16
  "xinsir/controlnet-union-sdxl-1.0",
17
+ filename="config_promax.json",
18
+ )
19
+
20
+ config = ControlNetModel_Union.load_config(config_file)
21
+ controlnet_model = ControlNetModel_Union.from_config(config)
22
+ model_file = hf_hub_download(
23
+ "xinsir/controlnet-union-sdxl-1.0",
24
+ filename="diffusion_pytorch_model_promax.safetensors",
25
+ )
26
+
27
+ sstate_dict = load_state_dict(model_file)
28
+ model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
29
+ controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
30
+ )
31
+ model.to(device="cuda", dtype=torch.float16)
32
+ #----------------------
33
 
34
  vae = AutoencoderKL.from_pretrained(
35
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
 
36
  ).to("cuda")
37
 
38
  pipe = StableDiffusionXLFillPipeline.from_pretrained(
39
  "SG161222/RealVisXL_V5.0_Lightning",
40
  torch_dtype=torch.float16,
41
  vae=vae,
42
+ controlnet=model,
43
  variant="fp16",
44
  ).to("cuda")
45
+
46
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
47
 
 
48
  def can_expand(source_width, source_height, target_width, target_height, alignment):
49
+ """Checks if the image can be expanded based on the alignment."""
50
  if alignment in ("Left", "Right") and source_width >= target_width:
51
  return False
52
  if alignment in ("Top", "Bottom") and source_height >= target_height:
53
  return False
54
  return True
55
 
56
+ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
57
+ target_size = (width, height)
58
+
59
+ # Calculate the scaling factor to fit the image within the target size
60
+ scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
61
+ new_width = int(image.width * scale_factor)
62
+ new_height = int(image.height * scale_factor)
63
+
64
+ # Resize the source image to fit within target size
65
+ source = image.resize((new_width, new_height), Image.LANCZOS)
66
+
67
+ # Apply resize option using percentages
68
+ if resize_option == "Full":
69
+ resize_percentage = 100
70
+ elif resize_option == "50%":
71
+ resize_percentage = 50
72
+ elif resize_option == "33%":
73
+ resize_percentage = 33
74
+ elif resize_option == "25%":
75
+ resize_percentage = 25
76
+ else: # Custom
77
+ resize_percentage = custom_resize_percentage
78
+
79
+ # Calculate new dimensions based on percentage
80
+ resize_factor = resize_percentage / 100
81
+ new_width = int(source.width * resize_factor)
82
+ new_height = int(source.height * resize_factor)
83
+
84
+ # Ensure minimum size of 64 pixels
85
+ new_width = max(new_width, 64)
86
+ new_height = max(new_height, 64)
87
+
88
+ # Resize the image
89
+ source = source.resize((new_width, new_height), Image.LANCZOS)
90
+
91
+ # Calculate the overlap in pixels based on the percentage
92
+ overlap_x = int(new_width * (overlap_percentage / 100))
93
+ overlap_y = int(new_height * (overlap_percentage / 100))
94
+
95
+ # Ensure minimum overlap of 1 pixel
96
+ overlap_x = max(overlap_x, 1)
97
+ overlap_y = max(overlap_y, 1)
98
+
99
+ # Calculate margins based on alignment
100
+ if alignment == "Middle":
101
+ margin_x = (target_size[0] - new_width) // 2
102
+ margin_y = (target_size[1] - new_height) // 2
103
+ elif alignment == "Left":
104
+ margin_x = 0
105
+ margin_y = (target_size[1] - new_height) // 2
106
+ elif alignment == "Right":
107
+ margin_x = target_size[0] - new_width
108
+ margin_y = (target_size[1] - new_height) // 2
109
+ elif alignment == "Top":
110
+ margin_x = (target_size[0] - new_width) // 2
111
+ margin_y = 0
112
+ elif alignment == "Bottom":
113
+ margin_x = (target_size[0] - new_width) // 2
114
+ margin_y = target_size[1] - new_height
115
+
116
+ # Adjust margins to eliminate gaps
117
+ margin_x = max(0, min(margin_x, target_size[0] - new_width))
118
+ margin_y = max(0, min(margin_y, target_size[1] - new_height))
119
+
120
+ # Create a new background image and paste the resized source image
121
+ background = Image.new('RGB', target_size, (255, 255, 255))
122
+ background.paste(source, (margin_x, margin_y))
123
+
124
+ # Create the mask
125
+ mask = Image.new('L', target_size, 255)
126
+ mask_draw = ImageDraw.Draw(mask)
127
+
128
+ # Calculate overlap areas
129
+ white_gaps_patch = 2
130
+
131
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
132
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
133
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
134
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
135
+
136
+ if alignment == "Left":
137
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
138
+ elif alignment == "Right":
139
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
140
+ elif alignment == "Top":
141
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
142
+ elif alignment == "Bottom":
143
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
144
+
145
+
146
+ # Draw the mask
147
+ mask_draw.rectangle([
148
+ (left_overlap, top_overlap),
149
+ (right_overlap, bottom_overlap)
150
+ ], fill=0)
151
+
152
+ return background, mask
153
+
154
+ def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
155
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
156
+
157
+ # Create a preview image showing the mask
158
+ preview = background.copy().convert('RGBA')
159
+
160
+ # Create a semi-transparent red overlay
161
+ red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) # Reduced alpha to 64 (25% opacity)
162
+
163
+ # Convert black pixels in the mask to semi-transparent red
164
+ red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
165
+ red_mask.paste(red_overlay, (0, 0), mask)
166
+
167
+ # Overlay the red mask on the background
168
+ preview = Image.alpha_composite(preview, red_mask)
169
+
170
+ return preview
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  @spaces.GPU(duration=24)
173
+ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
174
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
175
+
 
 
 
 
 
 
176
  if not can_expand(background.width, background.height, width, height, alignment):
177
  alignment = "Middle"
178
 
179
+ cnet_image = background.copy()
180
+ cnet_image.paste(0, (0, 0), mask)
181
 
182
  final_prompt = f"{prompt_input} , high quality, 4k"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ (
185
+ prompt_embeds,
186
+ negative_prompt_embeds,
187
+ pooled_prompt_embeds,
188
+ negative_pooled_prompt_embeds,
189
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
190
+
191
+ for image in pipe(
192
+ prompt_embeds=prompt_embeds,
193
+ negative_prompt_embeds=negative_prompt_embeds,
194
+ pooled_prompt_embeds=pooled_prompt_embeds,
195
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
196
+ image=cnet_image,
197
+ num_inference_steps=num_inference_steps
198
+ ):
199
+ yield cnet_image, image
200
 
201
+ image = image.convert("RGBA")
202
+ cnet_image.paste(image, (0, 0), mask)
203
 
204
+ yield background, cnet_image
205
 
206
  def clear_result():
207
+ """Clears the result ImageSlider."""
208
  return gr.update(value=None)
209
 
210
+ def preload_presets(target_ratio, ui_width, ui_height):
211
+ """Updates the width and height sliders based on the selected aspect ratio."""
212
+ if target_ratio == "9:16":
213
+ changed_width = 720
214
+ changed_height = 1280
215
+ return changed_width, changed_height, gr.update()
216
+ elif target_ratio == "16:9":
217
+ changed_width = 1280
218
+ changed_height = 720
219
+ return changed_width, changed_height, gr.update()
220
+ elif target_ratio == "1:1":
221
+ changed_width = 1024
222
+ changed_height = 1024
223
+ return changed_width, changed_height, gr.update()
224
+ elif target_ratio == "Custom":
225
+ return ui_width, ui_height, gr.update(open=True)
226
+
227
+ def select_the_right_preset(user_width, user_height):
228
+ if user_width == 720 and user_height == 1280:
229
+ return "9:16"
230
+ elif user_width == 1280 and user_height == 720:
231
+ return "16:9"
232
+ elif user_width == 1024 and user_height == 1024:
233
+ return "1:1"
234
+ else:
235
+ return "Custom"
236
+
237
+ def toggle_custom_resize_slider(resize_option):
238
+ return gr.update(visible=(resize_option == "Custom"))
239
+
240
+ def update_history(new_image, history):
241
+ """Updates the history gallery with the new image."""
242
+ if history is None:
243
+ history = []
244
+ history.insert(0, new_image)
245
  return history
246
 
247
+ css = """
248
+ .gradio-container {
249
+ width: 1200px !important;
250
+ }
251
+ h1 { text-align: center; }
252
+ footer { visibility: hidden; }
253
+ """
254
 
255
+ title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
256
+ """
257
  with gr.Blocks(css=css) as demo:
258
  gr.HTML(title)
259
  with gr.Row():