theSure committed (verified) · Commit ba2f501 · 1 parent: 585ea4e

Update app.py

Files changed (1):
  1. app.py +117 -87
app.py CHANGED
@@ -7,55 +7,52 @@ import random
  import spaces
  import gradio as gr
  import numpy as np
  from PIL import Image, ImageCms
  import torch
  from diffusers import FluxTransformer2DModel
  from diffusers.utils import load_image
  from pipeline_flux_control_removal import FluxControlRemovalPipeline
-
- # Model initialization
  pipe = None
  torch.set_grad_enabled(False)
-
- # Example data
- image_examples = [
- [
- "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
- "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
- ],
- [
- "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
- "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
- ],
- [
- "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
- "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
- ],
- [
- "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
- "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
- ],
- [
- "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
- "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
- ],
- [
- "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
- "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
- ],
- [
- "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
- "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
- ]

- ]

- # Model loading code (unchanged)
  base_model_path = 'black-forest-labs/FLUX.1-dev'
  lora_path = 'theSure/Omnieraser'
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
  gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
-
  with torch.no_grad():
  initial_input_channels = transformer.config.in_channels
  new_linear = torch.nn.Linear(
@@ -71,7 +68,6 @@ with torch.no_grad():
  new_linear.bias.copy_(transformer.x_embedder.bias)
  transformer.x_embedder = new_linear
  transformer.register_to_config(in_channels=initial_input_channels*4)
-
  pipe = FluxControlRemovalPipeline.from_pretrained(
  base_model_path,
  transformer=transformer,
@@ -83,7 +79,6 @@ gr.Info(str(f"Inject LoRA: {lora_path}"))
  pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
  gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))

- # Helper functions
  @spaces.GPU
  def set_seed(seed):
  torch.manual_seed(seed)
@@ -92,7 +87,6 @@ def set_seed(seed):
  np.random.seed(seed)
  random.seed(seed)

- # Main processing function
  @spaces.GPU
  def predict(
  input_image,
@@ -100,14 +94,16 @@ def predict(
  ddim_steps,
  seed,
  scale,
- image_state,  # use State instead of global variables
- mask_state  # use State instead of global variables
  ):
- if image_state is not None and mask_state is not None:
- input_image["background"] = load_image(image_state).convert("RGB")
- input_image["layers"][0] = load_image(mask_state).convert("RGB")

- # Keep the original image processing logic unchanged
  size1, size2 = input_image["background"].convert("RGB").size
  icc_profile = input_image["background"].info.get('icc_profile')
  if icc_profile:
@@ -118,17 +114,25 @@ def predict(
  input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
  input_image["background"].info.pop('icc_profile', None)

- # ... keep the original image resizing logic unchanged

- # Keep the original seed handling logic
  if seed == -1:
  seed = random.randint(1, 2147483647)
- set_seed(seed)
  else:
  set_seed(seed)
-
- # Keep the original mask handling logic
- if image_state is None:
  img=input_image["layers"][0]
  img_data = np.array(img)
  alpha_channel = img_data[:, :, 3]
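Note: the ICC handling at the top of this hunk converts backgrounds that carry an embedded colour profile to sRGB so the pipeline sees consistent colours. A small self-contained sketch of that normalization with Pillow's ImageCms (the profile objects are built here for illustration; in app.py they come from the lines elided above):

import io
from PIL import Image, ImageCms

def to_srgb(img: Image.Image) -> Image.Image:
    # Convert an image with an embedded ICC profile to sRGB; leave others untouched.
    icc = img.info.get("icc_profile")
    if not icc:
        return img
    src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
    srgb_profile = ImageCms.createProfile("sRGB")
    out = ImageCms.profileToProfile(img, src_profile, srgb_profile)
    out.info.pop("icc_profile", None)
    return out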
@@ -138,8 +142,8 @@ def predict(
  gray_image_pil = Image.fromarray(gray_image).convert('L')
  else:
  gray_image_pil = input_image["layers"][0]
-
- # Keep the original generation logic
  result = pipe(
  prompt=prompt,
  control_image=input_image["background"].convert("RGB"),
@@ -152,7 +156,6 @@ def predict(
  max_sequence_length=512,
  ).images[0]

- # Keep the original post-processing logic
  mask_np = np.array(input_image["layers"][0].convert("RGB"))
  red = np.array(input_image["background"]).astype("float") * 1
  red[:, :, 0] = 180.0
@@ -166,19 +169,42 @@ def predict(
  )

  dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
  dict_out = [result]
-
  return dict_out, dict_res

- # Example handling function
  def process_example(image_paths, mask_paths):
  image = Image.open(image_paths).convert("RGB")
  mask = Image.open(mask_paths).convert("L")
  black_background = Image.new("RGB", image.size, (0, 0, 0))
  masked_image = Image.composite(black_background, image, mask)
- return masked_image, image_paths, mask_paths  # return the paths to the State components
-
- # UI layout (keep the original CSS and layout logic)
  custom_css = """

  .contain { max-width: 1200px !important; }
@@ -235,7 +261,6 @@ custom_css = """
  .panel { height: 100%; }
  """

-
  with gr.Blocks(
  css=custom_css,
  theme=gr.themes.Soft(
@@ -245,20 +270,23 @@ with gr.Blocks(
  ),
  title="Omnieraser"
  ) as demo:
- # Add state storage
- image_state = gr.State()
- mask_state = gr.State()

- # Keep the original component declarations
  ddim_steps = gr.Slider(visible=False, value=28)
  scale = gr.Slider(visible=False, value=3.5)
  seed = gr.Slider(visible=False, value=-1)
  removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")

- # Keep the original UI layout
  with gr.Row(equal_height=False):
  with gr.Column(scale=1, variant="panel"):
  gr.Markdown("## 📥 Input Panel")
  with gr.Group():
  input_image = gr.Sketchpad(
  sources=["upload"],
@@ -268,7 +296,11 @@ with gr.Blocks(
  interactive=True
  )
  with gr.Row(variant="compact"):
- run_button = gr.Button("🚀 Start Processing", variant="primary", size="lg")
  with gr.Group():
  gr.Markdown("### ⚙️ Control Parameters")
  seed = gr.Slider(
@@ -279,21 +311,21 @@ with gr.Blocks(
  step=1,
  info="-1 for random generation"
  )
- with gr.Column(variant="panel"):
- gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
- example = gr.Examples(
- examples=image_examples,
- inputs=[
- gr.Image(label="Image", type="filepath",visible=False),
- gr.Image(label="Mask", type="filepath",visible=False)
- ],
- outputs=[input_image, image_state, mask_state],  # update the state outputs
- fn=process_example,
- run_on_click=True,
- examples_per_page=10,
- label="Click any example to load",
- elem_id="inline-examples"
- )

  with gr.Column(scale=1, variant="panel"):
  gr.Markdown("## 📤 Output Panel")
@@ -306,6 +338,7 @@ with gr.Blocks(
  preview=True,
  object_fit="contain"
  )
  with gr.Tab("Visualization Steps"):
  gallery = gr.Gallery(
  label="Workflow Steps",
@@ -314,19 +347,16 @@ with gr.Blocks(
  object_fit="contain"
  )

- # Update the button click event
  run_button.click(
- fn=lambda i, d, s, sc, rp, img, msk: predict(i, rp, d, s, sc, img, msk),
  inputs=[
  input_image,
  ddim_steps,
  seed,
  scale,
  removal_prompt,
- image_state,  # add state inputs
- mask_state  # add state inputs
  ],
  outputs=[inpaint_result, gallery]
  )
-
- demo.launch()
@@ -7,55 +7,52 @@ import random
  import spaces
  import gradio as gr
  import numpy as np
+
  from PIL import Image, ImageCms
  import torch
  from diffusers import FluxTransformer2DModel
  from diffusers.utils import load_image
  from pipeline_flux_control_removal import FluxControlRemovalPipeline
  pipe = None
  torch.set_grad_enabled(False)
+ image_path = mask_path =None
+ # image_examples = [
+ # [
+ # "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
+ # "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
+ # ],
+ # [
+ # "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
+ # "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
+ # ],
+ # [
+ # "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
+ # "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
+ # ],
+ # [
+ # "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
+ # "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
+ # ],
+ # [
+ # "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
+ # "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
+ # ],
+ # [
+ # "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
+ # "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
+ # ],
+ # [
+ # "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
+ # "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
+ # ]

+ # ]

  base_model_path = 'black-forest-labs/FLUX.1-dev'
  lora_path = 'theSure/Omnieraser'
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
  gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
+ # enable image inputs
  with torch.no_grad():
  initial_input_channels = transformer.config.in_channels
  new_linear = torch.nn.Linear(
68
  new_linear.bias.copy_(transformer.x_embedder.bias)
69
  transformer.x_embedder = new_linear
70
  transformer.register_to_config(in_channels=initial_input_channels*4)
 
71
  pipe = FluxControlRemovalPipeline.from_pretrained(
72
  base_model_path,
73
  transformer=transformer,
 
79
  pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
80
  gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
81
 
 
82
  @spaces.GPU
83
  def set_seed(seed):
84
  torch.manual_seed(seed)
 
@@ -92,7 +87,6 @@ def set_seed(seed):
  np.random.seed(seed)
  random.seed(seed)

  @spaces.GPU
  def predict(
  input_image,
 
@@ -100,14 +94,16 @@ def predict(
  ddim_steps,
  seed,
  scale,
+ image_paths,
+ mask_paths
+
  ):
+ global image_path, mask_path
+ gr.Info(str(f"Set seed = {seed}"))
+ if image_paths is not None:
+ input_image["background"] = load_image(image_paths).convert("RGB")
+ input_image["layers"][0] = load_image(mask_paths).convert("RGB")

  size1, size2 = input_image["background"].convert("RGB").size
  icc_profile = input_image["background"].info.get('icc_profile')
  if icc_profile:
 
@@ -118,17 +114,25 @@ def predict(
  input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
  input_image["background"].info.pop('icc_profile', None)

+ if size1 < size2:
+ input_image["background"] = input_image["background"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
+ else:
+ input_image["background"] = input_image["background"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
+
+ img = np.array(input_image["background"].convert("RGB"))
+
+ W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
+ H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
+
+ input_image["background"] = input_image["background"].resize((H, W))
+ input_image["layers"][0] = input_image["layers"][0].resize((H, W))

  if seed == -1:
  seed = random.randint(1, 2147483647)
+ set_seed(random.randint(1, 2147483647))
  else:
  set_seed(seed)
+ if image_paths is None:
  img=input_image["layers"][0]
  img_data = np.array(img)
  alpha_channel = img_data[:, :, 3]
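Note: the added resize logic first scales the shorter side of the background to 1024 while keeping the aspect ratio, then snaps both sides down to multiples of 8 before resizing the background and the mask layer together. Since np.shape(img)[0] is the pixel height, W and H here are swapped relative to PIL's (width, height) convention, but the final resize((H, W)) calls are consistent with that. Also visible above: when seed == -1, one random seed is drawn into `seed` while the RNGs are seeded with a second, independent draw, and the gr.Info message at the top of predict still reports the original value. A standalone sketch of the same resize computation (the helper name is mine, not from app.py):

from PIL import Image

def resize_for_flux(img: Image.Image, short_side: int = 1024, multiple: int = 8) -> Image.Image:
    # Scale the shorter side to `short_side`, preserving aspect ratio.
    w, h = img.size
    if w < h:
        img = img.resize((short_side, int(h / w * short_side)))
    else:
        img = img.resize((int(w / h * short_side), short_side))
    # Snap both sides down to a multiple of `multiple`.
    w, h = img.size
    return img.resize((w - w % multiple, h - h % multiple))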
 
@@ -138,8 +142,8 @@ def predict(
  gray_image_pil = Image.fromarray(gray_image).convert('L')
  else:
  gray_image_pil = input_image["layers"][0]
+ base_model_path = 'black-forest-labs/FLUX.1-dev'
+ lora_path = 'theSure/Omnieraser'
  result = pipe(
  prompt=prompt,
  control_image=input_image["background"].convert("RGB"),
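Note: when no example paths are set, the mask comes from the Sketchpad drawing layer: the layer's alpha channel marks the painted pixels and is turned into the grayscale mask the pipeline consumes (the thresholding step sits in the elided lines). An illustrative sketch of that conversion, assuming any non-zero alpha counts as masked:

import numpy as np
from PIL import Image

def layer_to_mask(layer: Image.Image) -> Image.Image:
    # The Sketchpad layer is RGBA; painted strokes have alpha > 0.
    alpha = np.array(layer.convert("RGBA"))[:, :, 3]
    gray = np.where(alpha > 0, 255, 0).astype(np.uint8)  # assumed threshold: any paint -> white
    return Image.fromarray(gray).convert("L")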
 
@@ -152,7 +156,6 @@ def predict(
  max_sequence_length=512,
  ).images[0]

  mask_np = np.array(input_image["layers"][0].convert("RGB"))
  red = np.array(input_image["background"]).astype("float") * 1
  red[:, :, 0] = 180.0
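Note: for the "Visualization Steps" gallery, predict builds result_m by pushing the red channel of the background toward 180 and blending that tint back in over the masked region (the blend itself is in the elided lines). A rough illustration of that kind of overlay, not the exact expression used in app.py:

import numpy as np
from PIL import Image

def highlight_mask(background: Image.Image, mask: Image.Image, alpha: float = 0.5) -> Image.Image:
    img = np.array(background.convert("RGB")).astype("float")
    red = img.copy()
    red[:, :, 0] = 180.0  # same red shift as in the hunk above
    m = (np.array(mask.convert("L")) > 127)[:, :, None]
    out = np.where(m, (1 - alpha) * img + alpha * red, img)  # tint only the masked pixels
    return Image.fromarray(out.astype("uint8"))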
 
@@ -166,19 +169,42 @@ def predict(
  )

  dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
+
  dict_out = [result]
+ image_path = None
+ mask_path = None
  return dict_out, dict_res
+
+
+ def infer(
+ input_image,
+ ddim_steps,
+ seed,
+ scale,
+ removal_prompt,
+
+ ):
+ img_path = image_path
+ msk_path = mask_path
+ return predict(input_image,
+ removal_prompt,
+ ddim_steps,
+ seed,
+ scale,
+ img_path,
+ msk_path
+ )

  def process_example(image_paths, mask_paths):
+ global image_path, mask_path
  image = Image.open(image_paths).convert("RGB")
  mask = Image.open(mask_paths).convert("L")
  black_background = Image.new("RGB", image.size, (0, 0, 0))
  masked_image = Image.composite(black_background, image, mask)
+
+ image_path = image_paths
+ mask_path = mask_paths
+ return masked_image
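Note: with this change the example flow runs through module-level globals rather than per-session gr.State: process_example caches the clicked example's paths in image_path / mask_path, infer snapshots them and forwards them to predict, and predict resets them after loading the files. Because module globals are shared by every session connected to the Space, two users clicking different examples at the same time can overwrite each other's paths; the previous revision kept this per-session. A minimal, self-contained sketch of the per-session alternative with gr.State (names here are illustrative, not the ones in app.py):

import gradio as gr

def cache_example(path, _cached):
    # Store the chosen path in this session's State and show a confirmation.
    return f"cached {path}", path

def run_with_cached(cached):
    # Read the per-session value back out when the user presses Run.
    return f"would run removal using {cached}"

with gr.Blocks() as sketch_demo:
    cached = gr.State()  # one copy per connected session
    path_box = gr.Textbox(label="Example path")
    status = gr.Textbox(label="Status")
    run = gr.Button("Run")
    path_box.submit(cache_example, inputs=[path_box, cached], outputs=[status, cached])
    run.click(run_with_cached, inputs=[cached], outputs=[status])

# sketch_demo.launch()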
  custom_css = """

  .contain { max-width: 1200px !important; }
 
@@ -235,7 +261,6 @@ custom_css = """
  .panel { height: 100%; }
  """

  with gr.Blocks(
  css=custom_css,
  theme=gr.themes.Soft(
 
@@ -245,20 +270,23 @@ with gr.Blocks(
  ),
  title="Omnieraser"
  ) as demo:

+
  ddim_steps = gr.Slider(visible=False, value=28)
  scale = gr.Slider(visible=False, value=3.5)
  seed = gr.Slider(visible=False, value=-1)
  removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")

+ gr.Markdown("""
+ <div align="center">
+ <h1 style="font-size: 2.5em; margin-bottom: 0.5em;">🪄 Omnieraser</h1>
+ </div>
+ """)
+
  with gr.Row(equal_height=False):
  with gr.Column(scale=1, variant="panel"):
  gr.Markdown("## 📥 Input Panel")
+
  with gr.Group():
  input_image = gr.Sketchpad(
  sources=["upload"],
 
@@ -268,7 +296,11 @@ with gr.Blocks(
  interactive=True
  )
  with gr.Row(variant="compact"):
+ run_button = gr.Button(
+ "🚀 Start Processing",
+ variant="primary",
+ size="lg"
+ )
  with gr.Group():
  gr.Markdown("### ⚙️ Control Parameters")
  seed = gr.Slider(
 
@@ -279,21 +311,21 @@ with gr.Blocks(
  step=1,
  info="-1 for random generation"
  )
+ # with gr.Column(variant="panel"):
+ # gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
+ # example = gr.Examples(
+ # examples=image_examples,
+ # inputs=[
+ # gr.Image(label="Image", type="filepath",visible=False),
+ # gr.Image(label="Mask", type="filepath",visible=False)
+ # ],
+ # outputs=[input_image],
+ # fn=process_example,
+ # run_on_click=True,
+ # examples_per_page=10,
+ # label="Click any example to load",
+ # elem_id="inline-examples"
+ # )
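Note: the Example Gallery block is commented out in this revision. If it is re-enabled later, the wiring inside the gr.Blocks layout would look roughly like the following, assuming the commented-out image_examples list near the top of the file is restored as well; with the globals-based flow, process_example only needs to feed the masked preview back into input_image:

with gr.Column(variant="panel"):
    gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
    example = gr.Examples(
        examples=image_examples,  # requires the commented-out list above
        inputs=[
            gr.Image(label="Image", type="filepath", visible=False),
            gr.Image(label="Mask", type="filepath", visible=False),
        ],
        outputs=[input_image],  # process_example now returns only the preview
        fn=process_example,
        run_on_click=True,
        examples_per_page=10,
        label="Click any example to load",
        elem_id="inline-examples",
    )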

  with gr.Column(scale=1, variant="panel"):
  gr.Markdown("## 📤 Output Panel")
 
@@ -306,6 +338,7 @@ with gr.Blocks(
  preview=True,
  object_fit="contain"
  )
+
  with gr.Tab("Visualization Steps"):
  gallery = gr.Gallery(
  label="Workflow Steps",
 
@@ -314,19 +347,16 @@ with gr.Blocks(
  object_fit="contain"
  )

  run_button.click(
+ fn=infer,
  inputs=[
  input_image,
  ddim_steps,
  seed,
  scale,
  removal_prompt,
  ],
  outputs=[inpaint_result, gallery]
  )
+
+ demo.launch()