theSure committed on
Commit fa19523 · verified · 1 Parent(s): cc07737

Update app.py

Files changed (1)
  1. app.py +355 -75
app.py CHANGED
@@ -1,93 +1,373 @@
-import gradio as gr
 import torch
-from diffusers import StableDiffusionInpaintPipeline
 import spaces
-from PIL import Image
 import numpy as np
-import random
-import os

-DESCRIPTION = "# Omnieraser\nRemove anything from any image using the [FLUX](https://huggingface.co/lllyasviel/flux) model."

-model_id = "lllyasviel/flux"
-lora_weights = "lllyasviel/flux-inpainting-internal"

-def load_pipeline():
-    pipe = StableDiffusionInpaintPipeline.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        variant="fp16"
     ).to("cuda")
-    pipe.load_lora_weights(lora_weights)
     return pipe

-def inference(pipe, image, mask):
-    image = image.convert("RGB").resize((512, 512))
-    mask = mask.convert("RGB").resize((512, 512))
-
-    generator = torch.Generator("cuda").manual_seed(random.randint(0, 999999))
-    image = pipe(prompt="", image=image, mask_image=mask, guidance_scale=7.5, generator=generator).images[0]
-    return image
-
-def process_example(example, pipe):
-    image_path, mask_path = example
-    image = Image.open(image_path).convert("RGB")
-    mask = Image.open(mask_path).convert("RGB")
-    return inference(pipe, image, mask)
-
-def get_random_examples(dataset_dir="examples"):
-    image_dir = os.path.join(dataset_dir, "images")
-    mask_dir = os.path.join(dataset_dir, "masks")
-    files = os.listdir(image_dir)
-    random.shuffle(files)
-    examples = [
-        [os.path.join(image_dir, f), os.path.join(mask_dir, f)] for f in files if os.path.exists(os.path.join(mask_dir, f))
-    ]
-    return examples[:30]

-def build_ui(pipe):
-    with gr.Blocks(css="style.css") as demo:
-        gr.Markdown(DESCRIPTION)

-        with gr.Row():
-            with gr.Column():
-                input_image = gr.Image(label="Input", type="pil")
-                mask_image = gr.Image(label="Mask", type="pil")

-                with gr.Row():
-                    submit = gr.Button("Run", elem_id="submit-button")

-            with gr.Column():
-                result_image = gr.Image(label="Output")

-        submit.click(
-            fn=lambda img, msk: inference(pipe, img, msk),
-            inputs=[input_image, mask_image],
-            outputs=result_image
-        )

-        gr.Markdown("## Examples")
-
-        image_examples = get_random_examples()
-        example = gr.Examples(
-            examples=image_examples,
-            inputs=[
-                gr.Image(label="Image", type="filepath", visible=False),
-                gr.Image(label="Mask", type="filepath", visible=False)
-            ],
-            outputs=[input_image],
-            fn=lambda example: process_example(example, pipe),
-            run_on_click=True,
-            label="Click any example to load",
-            elem_id="inline-examples"
-        )

-        gr.Markdown("Try drawing over objects you want to remove.")

-    return demo

-# Load the pipeline and run the UI
-if __name__ == "__main__":
-    pipe = load_pipeline()
-    demo = build_ui(pipe)
-    demo.launch()
+import io
+import os
+import shutil
+import uuid
 import torch
+import random
 import spaces
+import gradio as gr
 import numpy as np

+from PIL import Image, ImageCms
+from diffusers import FluxTransformer2DModel
+from diffusers.utils import load_image
+from pipeline_flux_control_removal import FluxControlRemovalPipeline
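+# NOTE: pipeline_flux_control_removal is a local module expected to sit next to
+# app.py in this Space; it is not part of the diffusers package.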
+
+torch.set_grad_enabled(False)
+image_path = mask_path = None
+image_examples = [
+    [
+        "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
+        "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
+    ],
+    [
+        "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
+        "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
+    ],
+    [
+        "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
+        "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
+    ],
+    [
+        "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
+        "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
+    ],
+    [
+        "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
+        "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
+    ],
+    [
+        "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
+        "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
+    ],
+    [
+        "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
+        "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
+    ]
+]
+
+@spaces.GPU(enable_queue=True)
+def load_model(base_model_path, lora_path):
+    transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
+    gr.Info("Model loading: 40%")
+    # enable image inputs
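+    # The patch embedder (x_embedder) is widened to 4x its input features with
+    # zero-initialized new columns: the pretrained projection is preserved for
+    # the original latent channels, while the extra channels (presumably the
+    # packed control-image and mask latents) start with zero influence until
+    # the removal LoRA is loaded.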
+    with torch.no_grad():
+        initial_input_channels = transformer.config.in_channels
+        new_linear = torch.nn.Linear(
+            transformer.x_embedder.in_features * 4,
+            transformer.x_embedder.out_features,
+            bias=transformer.x_embedder.bias is not None,
+            dtype=transformer.dtype,
+            device=transformer.device,
+        )
+        new_linear.weight.zero_()
+        new_linear.weight[:, :initial_input_channels].copy_(transformer.x_embedder.weight)
+
+        if transformer.x_embedder.bias is not None:
+            new_linear.bias.copy_(transformer.x_embedder.bias)
+
+        transformer.x_embedder = new_linear
+        transformer.register_to_config(in_channels=initial_input_channels * 4)
+
+    pipe = FluxControlRemovalPipeline.from_pretrained(
+        base_model_path,
+        transformer=transformer,
+        torch_dtype=torch.bfloat16
     ).to("cuda")
+    pipe.transformer.to(torch.bfloat16)
+    gr.Info("Model loading: 80%")
+    gr.Info(f"Inject LoRA: {lora_path}")
+    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
+    gr.Info("Model loading: 100%")
     return pipe
+@spaces.GPU(enable_queue=True)
+def set_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+@spaces.GPU(enable_queue=True)
+def predict(
+    pipe,
+    input_image,
+    prompt,
+    ddim_steps,
+    seed,
+    scale,
+    image_paths,
+    mask_paths
+):
+    global image_path, mask_path
+    gr.Info(f"Set seed = {seed}")
+    if image_paths is not None:
+        input_image["background"] = load_image(image_paths).convert("RGB")
+        input_image["layers"][0] = load_image(mask_paths).convert("RGB")
+
+    size1, size2 = input_image["background"].size
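+    # Convert tagged images to sRGB up front: np.array() drops the ICC profile,
+    # so a wide-gamut photo would otherwise come back with shifted colors.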
+    icc_profile = input_image["background"].info.get('icc_profile')
+    if icc_profile:
+        gr.Info("Image detected to contain ICC profile, converting color space to sRGB...")
+        srgb_profile = ImageCms.createProfile("sRGB")
+        io_handle = io.BytesIO(icc_profile)
+        src_profile = ImageCms.ImageCmsProfile(io_handle)
+        input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
+        input_image["background"].info.pop('icc_profile', None)
+
+    if size1 < size2:
+        input_image["background"] = input_image["background"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
+    else:
+        input_image["background"] = input_image["background"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
+
+    img = np.array(input_image["background"].convert("RGB"))
+
+    W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
+    H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
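+    # W and H are the numpy row/column counts (height and width) snapped down
+    # to multiples of 8 for the VAE; PIL's resize takes (width, height) = (H, W).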
+
+    input_image["background"] = input_image["background"].resize((H, W))
+    input_image["layers"][0] = input_image["layers"][0].resize((H, W))
+
+    if seed == -1:
+        seed = random.randint(1, 2147483647)
+    set_seed(seed)
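+    # With no example selected, the mask comes from the Sketchpad drawing: the
+    # drawn layer's alpha channel marks the strokes, which become a
+    # white-on-black grayscale mask (white = region to erase).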
+    if image_paths is None:
+        img = input_image["layers"][0]
+        img_data = np.array(img)
+        alpha_channel = img_data[:, :, 3]
+        white_background = np.ones_like(alpha_channel) * 255
+        gray_image = white_background.copy()
+        gray_image[alpha_channel == 0] = 0
+        gray_image_pil = Image.fromarray(gray_image).convert('L')
+    else:
+        gray_image_pil = input_image["layers"][0]
+    result = pipe(
+        prompt=prompt,
+        control_image=input_image["background"].convert("RGB"),
+        control_mask=gray_image_pil.convert("RGB"),
+        width=H,
+        height=W,
+        num_inference_steps=ddim_steps,
+        generator=torch.Generator("cuda").manual_seed(seed),
+        guidance_scale=scale,
+        max_sequence_length=512,
+    ).images[0]
+
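+    # Build a red-tinted visualization of the masked region. Dividing the 0-255
+    # mask by 512 rather than 255 caps the blend at roughly 50% opacity, so the
+    # original content stays visible under the highlight.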
+    mask_np = np.array(input_image["layers"][0].convert("RGB"))
+    red = np.array(input_image["background"]).astype("float")
+    red[:, :, 0] = 180.0
+    red[:, :, 1] = 0
+    red[:, :, 2] = 0
+    result_m = np.array(input_image["background"])
+    result_m = Image.fromarray(
+        (
+            result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
+            + mask_np.astype("float") / 512.0 * red
+        ).astype("uint8")
+    )
+
+    dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
+    dict_out = [result]
+    image_path = None
+    mask_path = None
+    return dict_out, dict_res
+
+
+def infer(
+    pipe,
+    input_image,
+    ddim_steps,
+    seed,
+    scale,
+    removal_prompt,
+):
+    img_path = image_path
+    msk_path = mask_path
+    return predict(
+        pipe,
+        input_image,
+        removal_prompt,
+        ddim_steps,
+        seed,
+        scale,
+        img_path,
+        msk_path,
+    )
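+# infer passes the module-level image_path/mask_path snapshot (set by
+# process_example when a gallery example is clicked), so predict can tell an
+# example run apart from a hand-drawn mask.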
+
+def process_example(image_paths, mask_paths):
+    global image_path, mask_path
+    image = Image.open(image_paths).convert("RGB")
+    mask = Image.open(mask_paths).convert("L")
+    black_background = Image.new("RGB", image.size, (0, 0, 0))
+    masked_image = Image.composite(black_background, image, mask)
+
+    image_path = image_paths
+    mask_path = mask_paths
+    return masked_image
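+# Image.composite(black, image, mask) blacks out the mask's white region,
+# previewing exactly what will be erased when the example runs.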
+
+custom_css = """
+.contain { max-width: 1200px !important; }
+
+.custom-image {
+    border: 2px dashed #7e22ce !important;
+    border-radius: 12px !important;
+    transition: all 0.3s ease !important;
+}
+.custom-image:hover {
+    border-color: #9333ea !important;
+    box-shadow: 0 4px 15px rgba(158, 109, 202, 0.2) !important;
+}
+
+.btn-primary {
+    background: linear-gradient(45deg, #7e22ce, #9333ea) !important;
+    border: none !important;
+    color: white !important;
+    border-radius: 8px !important;
+}
+
+#inline-examples {
+    border: 1px solid #e2e8f0 !important;
+    border-radius: 12px !important;
+    padding: 16px !important;
+    margin-top: 8px !important;
+}
+
+#inline-examples .thumbnail {
+    border-radius: 8px !important;
+    transition: transform 0.2s ease !important;
+}
+
+#inline-examples .thumbnail:hover {
+    transform: scale(1.05);
+    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
+}
+
+.example-title h3 {
+    margin: 0 0 12px 0 !important;
+    color: #475569 !important;
+    font-size: 1.1em !important;
+    display: flex !important;
+    align-items: center !important;
+}
+
+.example-title h3::before {
+    content: "📚";
+    margin-right: 8px;
+    font-size: 1.2em;
+}
+
+.row { align-items: stretch !important; }
+
+.panel { height: 100%; }
+"""
+
+with gr.Blocks(
+    css=custom_css,
+    theme=gr.themes.Soft(
+        primary_hue="purple",
+        secondary_hue="purple",
+        font=[gr.themes.GoogleFont('Inter'), 'sans-serif']
+    ),
+    title="Omnieraser"
+) as demo:
+    base_model_path = 'black-forest-labs/FLUX.1-dev'
+    lora_path = 'theSure/Omnieraser'
+    pipe = load_model(base_model_path=base_model_path, lora_path=lora_path)
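+    # The pipeline is built once, when the Blocks UI is constructed, and then
+    # captured by the click handler below instead of being reloaded per request.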
+
+    ddim_steps = gr.Slider(visible=False, value=28)
+    scale = gr.Slider(visible=False, value=3.5)
+    seed = gr.Slider(visible=False, value=-1)
+    removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")
+
+    gr.Markdown("""
+    <div align="center">
+    <h1 style="font-size: 2.5em; margin-bottom: 0.5em;">🪄 Omnieraser</h1>
+    </div>
+    """)
+
+    with gr.Row(equal_height=False):
+        with gr.Column(scale=1, variant="panel"):
+            gr.Markdown("## 📥 Input Panel")
+
+            with gr.Group():
+                input_image = gr.Sketchpad(
+                    sources=["upload"],
+                    type="pil",
+                    label="Upload & Annotate",
+                    elem_id="custom-image",
+                    interactive=True
+                )
+            with gr.Row(variant="compact"):
+                run_button = gr.Button(
+                    "🚀 Start Processing",
+                    variant="primary",
+                    size="lg"
+                )
+            with gr.Group():
+                gr.Markdown("### ⚙️ Control Parameters")
+                seed = gr.Slider(
+                    label="Random Seed",
+                    minimum=-1,
+                    maximum=2147483647,
+                    value=1234,
+                    step=1,
+                    info="-1 for random generation"
+                )
+            with gr.Column(variant="panel"):
+                gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
+                example = gr.Examples(
+                    examples=image_examples,
+                    inputs=[
+                        gr.Image(label="Image", type="filepath", visible=False),
+                        gr.Image(label="Mask", type="filepath", visible=False)
+                    ],
+                    outputs=[input_image],
+                    fn=process_example,
+                    run_on_click=True,
+                    examples_per_page=10,
+                    label="Click any example to load",
+                    elem_id="inline-examples"
+                )
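+                # run_on_click=True runs process_example on selection, so the
+                # Sketchpad immediately shows the example with its mask blacked out.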
+
+        with gr.Column(scale=1, variant="panel"):
+            gr.Markdown("## 📤 Output Panel")
+            with gr.Tabs():
+                with gr.Tab("Final Result"):
+                    inpaint_result = gr.Gallery(
+                        label="Generated Image",
+                        columns=2,
+                        height=450,
+                        preview=True,
+                        object_fit="contain"
+                    )
+
+                with gr.Tab("Visualization Steps"):
+                    gallery = gr.Gallery(
+                        label="Workflow Steps",
+                        columns=2,
+                        height=450,
+                        object_fit="contain"
+                    )
+
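+    # Gradio event inputs must be UI components, so the pipeline object is
+    # bound into the callback via a closure rather than passed as an input.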
+    run_button.click(
+        fn=lambda img, steps, sd, sc, prompt: infer(pipe, img, steps, sd, sc, prompt),
+        inputs=[
+            input_image,
+            ddim_steps,
+            seed,
+            scale,
+            removal_prompt,
+        ],
+        outputs=[inpaint_result, gallery]
+    )
+
+demo.launch()