theSure commited on
Commit
69dbad3
·
verified ·
1 Parent(s): ba2f501

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -91
app.py CHANGED
@@ -15,44 +15,44 @@ from diffusers.utils import load_image
15
  from pipeline_flux_control_removal import FluxControlRemovalPipeline
16
  pipe = None
17
  torch.set_grad_enabled(False)
18
- image_path = mask_path =None
19
- # image_examples = [
20
- # [
21
- # "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
22
- # "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
23
- # ],
24
- # [
25
- # "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
26
- # "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
27
- # ],
28
- # [
29
- # "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
30
- # "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
31
- # ],
32
- # [
33
- # "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
34
- # "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
35
- # ],
36
- # [
37
- # "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
38
- # "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
39
- # ],
40
- # [
41
- # "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
42
- # "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
43
- # ],
44
- # [
45
- # "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
46
- # "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
47
- # ]
48
 
49
- # ]
 
50
 
51
  base_model_path = 'black-forest-labs/FLUX.1-dev'
52
  lora_path = 'theSure/Omnieraser'
53
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
54
  gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
55
- # enable image inputs
56
  with torch.no_grad():
57
  initial_input_channels = transformer.config.in_channels
58
  new_linear = torch.nn.Linear(
@@ -90,64 +90,48 @@ def set_seed(seed):
90
  @spaces.GPU
91
  def predict(
92
  input_image,
 
93
  prompt,
94
  ddim_steps,
95
  seed,
96
  scale,
97
- image_paths,
98
- mask_paths
99
 
100
  ):
101
- global image_path, mask_path
102
- gr.Info(str(f"Set seed = {seed}"))
103
- if image_paths is not None:
104
- input_image["background"] = load_image(image_paths).convert("RGB")
105
- input_image["layers"][0] = load_image(mask_paths).convert("RGB")
106
-
107
- size1, size2 = input_image["background"].convert("RGB").size
108
- icc_profile = input_image["background"].info.get('icc_profile')
109
  if icc_profile:
110
  gr.Info(str(f"Image detected to contain ICC profile, converting color space to sRGB..."))
111
  srgb_profile = ImageCms.createProfile("sRGB")
112
  io_handle = io.BytesIO(icc_profile)
113
  src_profile = ImageCms.ImageCmsProfile(io_handle)
114
- input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
115
- input_image["background"].info.pop('icc_profile', None)
116
 
117
  if size1 < size2:
118
- input_image["background"] = input_image["background"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
119
  else:
120
- input_image["background"] = input_image["background"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
121
 
122
- img = np.array(input_image["background"].convert("RGB"))
123
 
124
  W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
125
  H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
126
 
127
- input_image["background"] = input_image["background"].resize((H, W))
128
- input_image["layers"][0] = input_image["layers"][0].resize((H, W))
129
 
130
  if seed == -1:
131
  seed = random.randint(1, 2147483647)
132
  set_seed(random.randint(1, 2147483647))
133
  else:
134
  set_seed(seed)
135
- if image_paths is None:
136
- img=input_image["layers"][0]
137
- img_data = np.array(img)
138
- alpha_channel = img_data[:, :, 3]
139
- white_background = np.ones_like(alpha_channel) * 255
140
- gray_image = white_background.copy()
141
- gray_image[alpha_channel == 0] = 0
142
- gray_image_pil = Image.fromarray(gray_image).convert('L')
143
- else:
144
- gray_image_pil = input_image["layers"][0]
145
  base_model_path = 'black-forest-labs/FLUX.1-dev'
146
  lora_path = 'theSure/Omnieraser'
147
  result = pipe(
148
  prompt=prompt,
149
- control_image=input_image["background"].convert("RGB"),
150
- control_mask=gray_image_pil.convert("RGB"),
151
  width=H,
152
  height=W,
153
  num_inference_steps=ddim_steps,
@@ -156,19 +140,19 @@ def predict(
156
  max_sequence_length=512,
157
  ).images[0]
158
 
159
- mask_np = np.array(input_image["layers"][0].convert("RGB"))
160
- red = np.array(input_image["background"]).astype("float") * 1
161
  red[:, :, 0] = 180.0
162
  red[:, :, 2] = 0
163
  red[:, :, 1] = 0
164
- result_m = np.array(input_image["background"])
165
  result_m = Image.fromarray(
166
  (
167
  result_m.astype("float") * (1 - mask_np.astype("float") / 512.0) + mask_np.astype("float") / 512.0 * red
168
  ).astype("uint8")
169
  )
170
 
171
- dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
172
 
173
  dict_out = [result]
174
  image_path = None
@@ -178,21 +162,19 @@ def predict(
178
 
179
  def infer(
180
  input_image,
 
181
  ddim_steps,
182
  seed,
183
  scale,
184
  removal_prompt,
185
 
186
  ):
187
- img_path = image_path
188
- msk_path = mask_path
189
  return predict(input_image,
 
190
  removal_prompt,
191
  ddim_steps,
192
  seed,
193
  scale,
194
- img_path,
195
- msk_path
196
  )
197
 
198
  def process_example(image_paths, mask_paths):
@@ -288,13 +270,8 @@ with gr.Blocks(
288
  gr.Markdown("## 📥 Input Panel")
289
 
290
  with gr.Group():
291
- input_image = gr.Sketchpad(
292
- sources=["upload"],
293
- type="pil",
294
- label="Upload & Annotate",
295
- elem_id="custom-image",
296
- interactive=True
297
- )
298
  with gr.Row(variant="compact"):
299
  run_button = gr.Button(
300
  "🚀 Start Processing",
@@ -311,21 +288,18 @@ with gr.Blocks(
311
  step=1,
312
  info="-1 for random generation"
313
  )
314
- # with gr.Column(variant="panel"):
315
- # gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
316
- # example = gr.Examples(
317
- # examples=image_examples,
318
- # inputs=[
319
- # gr.Image(label="Image", type="filepath",visible=False),
320
- # gr.Image(label="Mask", type="filepath",visible=False)
321
- # ],
322
- # outputs=[input_image],
323
- # fn=process_example,
324
- # run_on_click=True,
325
- # examples_per_page=10,
326
- # label="Click any example to load",
327
- # elem_id="inline-examples"
328
- # )
329
 
330
  with gr.Column(scale=1, variant="panel"):
331
  gr.Markdown("## 📤 Output Panel")
@@ -351,6 +325,7 @@ with gr.Blocks(
351
  fn=infer,
352
  inputs=[
353
  input_image,
 
354
  ddim_steps,
355
  seed,
356
  scale,
 
15
  from pipeline_flux_control_removal import FluxControlRemovalPipeline
16
  pipe = None
17
  torch.set_grad_enabled(False)
18
+
19
+ image_examples = [
20
+ [
21
+ "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
22
+ "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
23
+ ],
24
+ [
25
+ "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
26
+ "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
27
+ ],
28
+ [
29
+ "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
30
+ "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
31
+ ],
32
+ [
33
+ "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
34
+ "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
35
+ ],
36
+ [
37
+ "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
38
+ "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
39
+ ],
40
+ [
41
+ "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
42
+ "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
43
+ ],
44
+ [
45
+ "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
46
+ "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
47
+ ]
48
 
49
+ ]
50
+
51
 
52
  base_model_path = 'black-forest-labs/FLUX.1-dev'
53
  lora_path = 'theSure/Omnieraser'
54
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
55
  gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
 
56
  with torch.no_grad():
57
  initial_input_channels = transformer.config.in_channels
58
  new_linear = torch.nn.Linear(
 
90
  @spaces.GPU
91
  def predict(
92
  input_image,
93
+ uploaded_mask
94
  prompt,
95
  ddim_steps,
96
  seed,
97
  scale,
 
 
98
 
99
  ):
100
+ gr.Info(str(f"Set seed = {seed}"))
101
+ size1, size2 = input_image.convert("RGB").size
102
+ icc_profile = input_image.info.get('icc_profile')
 
 
 
 
 
103
  if icc_profile:
104
  gr.Info(str(f"Image detected to contain ICC profile, converting color space to sRGB..."))
105
  srgb_profile = ImageCms.createProfile("sRGB")
106
  io_handle = io.BytesIO(icc_profile)
107
  src_profile = ImageCms.ImageCmsProfile(io_handle)
108
+ input_image = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
109
+ input_image.info.pop('icc_profile', None)
110
 
111
  if size1 < size2:
112
+ input_image = input_image.convert("RGB").resize((1024, int(size2 / size1 * 1024)))
113
  else:
114
+ input_image = input_image.convert("RGB").resize((int(size1 / size2 * 1024), 1024))
115
 
116
+ img = np.array(input_image.convert("RGB"))
117
 
118
  W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
119
  H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
120
 
121
+ input_image = input_imageresize((H, W))
122
+ uploaded_mask = uploaded_mask.resize((H, W))
123
 
124
  if seed == -1:
125
  seed = random.randint(1, 2147483647)
126
  set_seed(random.randint(1, 2147483647))
127
  else:
128
  set_seed(seed)
 
 
 
 
 
 
 
 
 
 
129
  base_model_path = 'black-forest-labs/FLUX.1-dev'
130
  lora_path = 'theSure/Omnieraser'
131
  result = pipe(
132
  prompt=prompt,
133
+ control_image=input_image.convert("RGB"),
134
+ control_mask=uploaded_mask.convert("RGB"),
135
  width=H,
136
  height=W,
137
  num_inference_steps=ddim_steps,
 
140
  max_sequence_length=512,
141
  ).images[0]
142
 
143
+ mask_np = np.array(uploaded_mask.convert("RGB"))
144
+ red = np.array(input_image).astype("float") * 1
145
  red[:, :, 0] = 180.0
146
  red[:, :, 2] = 0
147
  red[:, :, 1] = 0
148
+ result_m = np.array(input_image)
149
  result_m = Image.fromarray(
150
  (
151
  result_m.astype("float") * (1 - mask_np.astype("float") / 512.0) + mask_np.astype("float") / 512.0 * red
152
  ).astype("uint8")
153
  )
154
 
155
+ dict_res = [input_image, uploaded_mask, result_m, result]
156
 
157
  dict_out = [result]
158
  image_path = None
 
162
 
163
  def infer(
164
  input_image,
165
+ uploaded_mask,
166
  ddim_steps,
167
  seed,
168
  scale,
169
  removal_prompt,
170
 
171
  ):
 
 
172
  return predict(input_image,
173
+ uploaded_mask
174
  removal_prompt,
175
  ddim_steps,
176
  seed,
177
  scale,
 
 
178
  )
179
 
180
  def process_example(image_paths, mask_paths):
 
270
  gr.Markdown("## 📥 Input Panel")
271
 
272
  with gr.Group():
273
+ image_input = gr.Image(label="Upload Image", type="pil", image_mode="RGB")
274
+ uploaded_mask = gr.Image(label="Upload Mask", type="pil", image_mode="L")
 
 
 
 
 
275
  with gr.Row(variant="compact"):
276
  run_button = gr.Button(
277
  "🚀 Start Processing",
 
288
  step=1,
289
  info="-1 for random generation"
290
  )
291
+ with gr.Column(variant="panel"):
292
+ gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
293
+ example = gr.Examples(
294
+ examples=image_examples,
295
+ inputs=[
296
+ image_input, uploaded_mask
297
+ ],
298
+ outputs=[inpaint_result, gallery],
299
+ examples_per_page=10,
300
+ label="Click any example to load",
301
+ elem_id="inline-examples"
302
+ )
 
 
 
303
 
304
  with gr.Column(scale=1, variant="panel"):
305
  gr.Markdown("## 📤 Output Panel")
 
325
  fn=infer,
326
  inputs=[
327
  input_image,
328
+ uploaded_mask
329
  ddim_steps,
330
  seed,
331
  scale,