theSure commited on
Commit
7163430
·
verified ·
1 Parent(s): e535e7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -72
app.py CHANGED
@@ -13,9 +13,12 @@ import torch
13
  from diffusers import FluxTransformer2DModel
14
  from diffusers.utils import load_image
15
  from pipeline_flux_control_removal import FluxControlRemovalPipeline
 
 
16
  pipe = None
17
  torch.set_grad_enabled(False)
18
- image_path = mask_path =None
 
19
  image_examples = [
20
  [
21
  "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
@@ -48,11 +51,12 @@ image_examples = [
48
 
49
  ]
50
 
 
51
  base_model_path = 'black-forest-labs/FLUX.1-dev'
52
  lora_path = 'theSure/Omnieraser'
53
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
54
  gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
55
- # enable image inputs
56
  with torch.no_grad():
57
  initial_input_channels = transformer.config.in_channels
58
  new_linear = torch.nn.Linear(
@@ -68,6 +72,7 @@ with torch.no_grad():
68
  new_linear.bias.copy_(transformer.x_embedder.bias)
69
  transformer.x_embedder = new_linear
70
  transformer.register_to_config(in_channels=initial_input_channels*4)
 
71
  pipe = FluxControlRemovalPipeline.from_pretrained(
72
  base_model_path,
73
  transformer=transformer,
@@ -79,6 +84,7 @@ gr.Info(str(f"Inject LoRA: {lora_path}"))
79
  pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
80
  gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
81
 
 
82
  @spaces.GPU
83
  def set_seed(seed):
84
  torch.manual_seed(seed)
@@ -87,6 +93,7 @@ def set_seed(seed):
87
  np.random.seed(seed)
88
  random.seed(seed)
89
 
 
90
  @spaces.GPU
91
  def predict(
92
  input_image,
@@ -94,16 +101,14 @@ def predict(
94
  ddim_steps,
95
  seed,
96
  scale,
97
- image_paths,
98
- mask_paths
99
-
100
  ):
101
-
102
- gr.Info(str(f"Set seed = {seed}"))
103
- if image_paths is not None:
104
- input_image["background"] = load_image(image_paths).convert("RGB")
105
- input_image["layers"][0] = load_image(mask_paths).convert("RGB")
106
 
 
107
  size1, size2 = input_image["background"].convert("RGB").size
108
  icc_profile = input_image["background"].info.get('icc_profile')
109
  if icc_profile:
@@ -114,25 +119,17 @@ def predict(
114
  input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
115
  input_image["background"].info.pop('icc_profile', None)
116
 
117
- if size1 < size2:
118
- input_image["background"] = input_image["background"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
119
- else:
120
- input_image["background"] = input_image["background"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
121
-
122
- img = np.array(input_image["background"].convert("RGB"))
123
-
124
- W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
125
- H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
126
-
127
- input_image["background"] = input_image["background"].resize((H, W))
128
- input_image["layers"][0] = input_image["layers"][0].resize((H, W))
129
 
 
130
  if seed == -1:
131
  seed = random.randint(1, 2147483647)
132
- set_seed(random.randint(1, 2147483647))
133
  else:
134
  set_seed(seed)
135
- if image_paths is None:
 
 
136
  img=input_image["layers"][0]
137
  img_data = np.array(img)
138
  alpha_channel = img_data[:, :, 3]
@@ -142,8 +139,8 @@ def predict(
142
  gray_image_pil = Image.fromarray(gray_image).convert('L')
143
  else:
144
  gray_image_pil = input_image["layers"][0]
145
- base_model_path = 'black-forest-labs/FLUX.1-dev'
146
- lora_path = 'theSure/Omnieraser'
147
  result = pipe(
148
  prompt=prompt,
149
  control_image=input_image["background"].convert("RGB"),
@@ -156,6 +153,7 @@ def predict(
156
  max_sequence_length=512,
157
  ).images[0]
158
 
 
159
  mask_np = np.array(input_image["layers"][0].convert("RGB"))
160
  red = np.array(input_image["background"]).astype("float") * 1
161
  red[:, :, 0] = 180.0
@@ -169,42 +167,19 @@ def predict(
169
  )
170
 
171
  dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
172
-
173
  dict_out = [result]
174
- image_path = None
175
- mask_path = None
176
  return dict_out, dict_res
177
-
178
-
179
- def infer(
180
- input_image,
181
- ddim_steps,
182
- seed,
183
- scale,
184
- removal_prompt,
185
-
186
- ):
187
- img_path = image_path
188
- msk_path = mask_path
189
- return predict(input_image,
190
- removal_prompt,
191
- ddim_steps,
192
- seed,
193
- scale,
194
- img_path,
195
- msk_path
196
- )
197
 
 
198
  def process_example(image_paths, mask_paths):
199
- global image_path, mask_path
200
  image = Image.open(image_paths).convert("RGB")
201
  mask = Image.open(mask_paths).convert("L")
202
  black_background = Image.new("RGB", image.size, (0, 0, 0))
203
  masked_image = Image.composite(black_background, image, mask)
204
-
205
- image_path = image_paths
206
- mask_path = mask_paths
207
- return masked_image
208
  custom_css = """
209
 
210
  .contain { max-width: 1200px !important; }
@@ -261,6 +236,7 @@ custom_css = """
261
  .panel { height: 100%; }
262
  """
263
 
 
264
  with gr.Blocks(
265
  css=custom_css,
266
  theme=gr.themes.Soft(
@@ -270,23 +246,20 @@ with gr.Blocks(
270
  ),
271
  title="Omnieraser"
272
  ) as demo:
 
 
 
273
 
274
-
275
  ddim_steps = gr.Slider(visible=False, value=28)
276
  scale = gr.Slider(visible=False, value=3.5)
277
  seed = gr.Slider(visible=False, value=-1)
278
  removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")
279
 
280
- gr.Markdown("""
281
- <div align="center">
282
- <h1 style="font-size: 2.5em; margin-bottom: 0.5em;">🪄 Omnieraser</h1>
283
- </div>
284
- """)
285
-
286
  with gr.Row(equal_height=False):
287
  with gr.Column(scale=1, variant="panel"):
288
  gr.Markdown("## 📥 Input Panel")
289
-
290
  with gr.Group():
291
  input_image = gr.Sketchpad(
292
  sources=["upload"],
@@ -296,11 +269,7 @@ with gr.Blocks(
296
  interactive=True
297
  )
298
  with gr.Row(variant="compact"):
299
- run_button = gr.Button(
300
- "🚀 Start Processing",
301
- variant="primary",
302
- size="lg"
303
- )
304
  with gr.Group():
305
  gr.Markdown("### ⚙️ Control Parameters")
306
  seed = gr.Slider(
@@ -319,7 +288,7 @@ with gr.Blocks(
319
  gr.Image(label="Image", type="filepath",visible=False),
320
  gr.Image(label="Mask", type="filepath",visible=False)
321
  ],
322
- outputs=[input_image],
323
  fn=process_example,
324
  run_on_click=True,
325
  examples_per_page=10,
@@ -338,7 +307,6 @@ with gr.Blocks(
338
  preview=True,
339
  object_fit="contain"
340
  )
341
-
342
  with gr.Tab("Visualization Steps"):
343
  gallery = gr.Gallery(
344
  label="Workflow Steps",
@@ -347,16 +315,19 @@ with gr.Blocks(
347
  object_fit="contain"
348
  )
349
 
 
350
  run_button.click(
351
- fn=infer,
352
  inputs=[
353
  input_image,
354
  ddim_steps,
355
  seed,
356
  scale,
357
  removal_prompt,
 
 
358
  ],
359
  outputs=[inpaint_result, gallery]
360
  )
361
-
362
- demo.launch()
 
13
  from diffusers import FluxTransformer2DModel
14
  from diffusers.utils import load_image
15
  from pipeline_flux_control_removal import FluxControlRemovalPipeline
16
+
17
+ # 初始化模型部分
18
  pipe = None
19
  torch.set_grad_enabled(False)
20
+
21
+ # 示例数据
22
  image_examples = [
23
  [
24
  "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
 
51
 
52
  ]
53
 
54
+ # 模型加载代码(保持不变)
55
  base_model_path = 'black-forest-labs/FLUX.1-dev'
56
  lora_path = 'theSure/Omnieraser'
57
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
58
  gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
59
+
60
  with torch.no_grad():
61
  initial_input_channels = transformer.config.in_channels
62
  new_linear = torch.nn.Linear(
 
72
  new_linear.bias.copy_(transformer.x_embedder.bias)
73
  transformer.x_embedder = new_linear
74
  transformer.register_to_config(in_channels=initial_input_channels*4)
75
+
76
  pipe = FluxControlRemovalPipeline.from_pretrained(
77
  base_model_path,
78
  transformer=transformer,
 
84
  pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
85
  gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
86
 
87
+ # 辅助函数
88
  @spaces.GPU
89
  def set_seed(seed):
90
  torch.manual_seed(seed)
 
93
  np.random.seed(seed)
94
  random.seed(seed)
95
 
96
+ # 主要处理函数
97
  @spaces.GPU
98
  def predict(
99
  input_image,
 
101
  ddim_steps,
102
  seed,
103
  scale,
104
+ image_state, # 使用State替代全局变量
105
+ mask_state # 使用State替代全局变量
 
106
  ):
107
+ if image_state is not None and mask_state is not None:
108
+ input_image["background"] = load_image(image_state).convert("RGB")
109
+ input_image["layers"][0] = load_image(mask_state).convert("RGB")
 
 
110
 
111
+ # 保持原有图像处理逻辑不变
112
  size1, size2 = input_image["background"].convert("RGB").size
113
  icc_profile = input_image["background"].info.get('icc_profile')
114
  if icc_profile:
 
119
  input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
120
  input_image["background"].info.pop('icc_profile', None)
121
 
122
+ # ... 保持原有图像尺寸调整逻辑不变
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ # 保持原有seed处理逻辑
125
  if seed == -1:
126
  seed = random.randint(1, 2147483647)
127
+ set_seed(seed)
128
  else:
129
  set_seed(seed)
130
+
131
+ # 保持原有mask处理逻辑
132
+ if image_state is None:
133
  img=input_image["layers"][0]
134
  img_data = np.array(img)
135
  alpha_channel = img_data[:, :, 3]
 
139
  gray_image_pil = Image.fromarray(gray_image).convert('L')
140
  else:
141
  gray_image_pil = input_image["layers"][0]
142
+
143
+ # 保持原有生成逻辑
144
  result = pipe(
145
  prompt=prompt,
146
  control_image=input_image["background"].convert("RGB"),
 
153
  max_sequence_length=512,
154
  ).images[0]
155
 
156
+ # 保持原有后处理逻辑
157
  mask_np = np.array(input_image["layers"][0].convert("RGB"))
158
  red = np.array(input_image["background"]).astype("float") * 1
159
  red[:, :, 0] = 180.0
 
167
  )
168
 
169
  dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
 
170
  dict_out = [result]
171
+
 
172
  return dict_out, dict_res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ # ���例处理函数
175
  def process_example(image_paths, mask_paths):
 
176
  image = Image.open(image_paths).convert("RGB")
177
  mask = Image.open(mask_paths).convert("L")
178
  black_background = Image.new("RGB", image.size, (0, 0, 0))
179
  masked_image = Image.composite(black_background, image, mask)
180
+ return masked_image, image_paths, mask_paths # 返回路径到State
181
+
182
+ # 界面布局(保持原有CSS和布局逻辑)
 
183
  custom_css = """
184
 
185
  .contain { max-width: 1200px !important; }
 
236
  .panel { height: 100%; }
237
  """
238
 
239
+
240
  with gr.Blocks(
241
  css=custom_css,
242
  theme=gr.themes.Soft(
 
246
  ),
247
  title="Omnieraser"
248
  ) as demo:
249
+ # 添加状态存储
250
+ image_state = gr.State()
251
+ mask_state = gr.State()
252
 
253
+ # 保持原有组件声明
254
  ddim_steps = gr.Slider(visible=False, value=28)
255
  scale = gr.Slider(visible=False, value=3.5)
256
  seed = gr.Slider(visible=False, value=-1)
257
  removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")
258
 
259
+ # 保持原有界面布局
 
 
 
 
 
260
  with gr.Row(equal_height=False):
261
  with gr.Column(scale=1, variant="panel"):
262
  gr.Markdown("## 📥 Input Panel")
 
263
  with gr.Group():
264
  input_image = gr.Sketchpad(
265
  sources=["upload"],
 
269
  interactive=True
270
  )
271
  with gr.Row(variant="compact"):
272
+ run_button = gr.Button("🚀 Start Processing", variant="primary", size="lg")
 
 
 
 
273
  with gr.Group():
274
  gr.Markdown("### ⚙️ Control Parameters")
275
  seed = gr.Slider(
 
288
  gr.Image(label="Image", type="filepath",visible=False),
289
  gr.Image(label="Mask", type="filepath",visible=False)
290
  ],
291
+ outputs=[input_image, image_state, mask_state], # 更新状态输出
292
  fn=process_example,
293
  run_on_click=True,
294
  examples_per_page=10,
 
307
  preview=True,
308
  object_fit="contain"
309
  )
 
310
  with gr.Tab("Visualization Steps"):
311
  gallery = gr.Gallery(
312
  label="Workflow Steps",
 
315
  object_fit="contain"
316
  )
317
 
318
+ # 更新按钮点击事件
319
  run_button.click(
320
+ fn=lambda i, d, s, sc, rp, img, msk: predict(i, rp, d, s, sc, img, msk),
321
  inputs=[
322
  input_image,
323
  ddim_steps,
324
  seed,
325
  scale,
326
  removal_prompt,
327
+ image_state, # 添加状态输入
328
+ mask_state # 添加状态输入
329
  ],
330
  outputs=[inpaint_result, gallery]
331
  )
332
+
333
+ demo.launch()