theSure committed
Commit cc07737 · verified · 1 Parent(s): f1243f2

Update app.py

Files changed (1)
  1. app.py +77 -352
app.py CHANGED
@@ -1,368 +1,93 @@
-import io
-import os
-import shutil
-import uuid
 import torch
-import random
 import spaces
-import gradio as gr
 import numpy as np
-
-from PIL import Image, ImageCms
-import torch
-from diffusers import FluxTransformer2DModel
-from diffusers.utils import load_image
-from pipeline_flux_control_removal import FluxControlRemovalPipeline
-
-torch.set_grad_enabled(False)
-image_path = mask_path = None
-image_examples = [...]
-image_path = mask_path = None
-image_examples = [
-    [
-        "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
-        "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
-    ],
-    [
-        "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
-        "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
-    ],
-    [
-        "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
-        "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
-    ],
-    [
-        "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
-        "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
-    ],
-    [
-        "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
-        "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
-    ],
-    [
-        "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
-        "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
-    ],
-    [
-        "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
-        "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
-    ]
-
-]
-
-@spaces.GPU(enable_queue=True)
-def load_model(base_model_path, lora_path):
-    global pipe
-    transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
-    gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
-    # enable image inputs
-    with torch.no_grad():
-        initial_input_channels = transformer.config.in_channels
-        new_linear = torch.nn.Linear(
-            transformer.x_embedder.in_features * 4,
-            transformer.x_embedder.out_features,
-            bias=transformer.x_embedder.bias is not None,
-            dtype=transformer.dtype,
-            device=transformer.device,
-        )
-        new_linear.weight.zero_()
-        new_linear.weight[:, :initial_input_channels].copy_(transformer.x_embedder.weight)
-
-        if transformer.x_embedder.bias is not None:
-            new_linear.bias.copy_(transformer.x_embedder.bias)
-
-        transformer.x_embedder = new_linear
-        transformer.register_to_config(in_channels=initial_input_channels * 4)
-
-    pipe = FluxControlRemovalPipeline.from_pretrained(
-        base_model_path,
-        transformer=transformer,
-        torch_dtype=torch.bfloat16
     ).to("cuda")
-    pipe.transformer.to(torch.bfloat16)
-    gr.Info(str(f"Model loading: {int((80 / 100) * 100)}%"))
-    gr.Info(str(f"Inject LoRA: {lora_path}"))
-    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
-    gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
-@spaces.GPU(enable_queue=True)
-def set_seed(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    random.seed(seed)
-
-@spaces.GPU(enable_queue=True)
-def predict(
-    input_image,
-    prompt,
-    ddim_steps,
-    seed,
-    scale,
-    image_paths,
-    mask_paths
-
-):
-    global image_path, mask_path
-    gr.Info(str(f"Set seed = {seed}"))
-    if image_paths is not None:
-        input_image["background"] = load_image(image_paths).convert("RGB")
-        input_image["layers"][0] = load_image(mask_paths).convert("RGB")
-
-    size1, size2 = input_image["background"].convert("RGB").size
-    icc_profile = input_image["background"].info.get('icc_profile')
-    if icc_profile:
-        gr.Info(str(f"Image detected to contain ICC profile, converting color space to sRGB..."))
-        srgb_profile = ImageCms.createProfile("sRGB")
-        io_handle = io.BytesIO(icc_profile)
-        src_profile = ImageCms.ImageCmsProfile(io_handle)
-        input_image["background"] = ImageCms.profileToProfile(input_image["background"], src_profile, srgb_profile)
-        input_image["background"].info.pop('icc_profile', None)
-
-    if size1 < size2:
-        input_image["background"] = input_image["background"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
-    else:
-        input_image["background"] = input_image["background"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
-
-    img = np.array(input_image["background"].convert("RGB"))
-
-    W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
-    H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
-
-    input_image["background"] = input_image["background"].resize((H, W))
-    input_image["layers"][0] = input_image["layers"][0].resize((H, W))
-
-    if seed == -1:
-        seed = random.randint(1, 2147483647)
-        set_seed(random.randint(1, 2147483647))
-    else:
-        set_seed(seed)
-    if image_paths is None:
-        img = input_image["layers"][0]
-        img_data = np.array(img)
-        alpha_channel = img_data[:, :, 3]
-        white_background = np.ones_like(alpha_channel) * 255
-        gray_image = white_background.copy()
-        gray_image[alpha_channel == 0] = 0
-        gray_image_pil = Image.fromarray(gray_image).convert('L')
-    else:
-        gray_image_pil = input_image["layers"][0]
-    result = pipe(
-        prompt=prompt,
-        control_image=input_image["background"].convert("RGB"),
-        control_mask=gray_image_pil.convert("RGB"),
-        width=H,
-        height=W,
-        num_inference_steps=ddim_steps,
-        generator=torch.Generator("cuda").manual_seed(seed),
-        guidance_scale=scale,
-        max_sequence_length=512,
-    ).images[0]
-
-    mask_np = np.array(input_image["layers"][0].convert("RGB"))
-    red = np.array(input_image["background"]).astype("float") * 1
-    red[:, :, 0] = 180.0
-    red[:, :, 2] = 0
-    red[:, :, 1] = 0
-    result_m = np.array(input_image["background"])
-    result_m = Image.fromarray(
-        (
-            result_m.astype("float") * (1 - mask_np.astype("float") / 512.0) + mask_np.astype("float") / 512.0 * red
-        ).astype("uint8")
-    )
-
-    dict_res = [input_image["background"], input_image["layers"][0], result_m, result]
-
-    dict_out = [result]
-    image_path = None
-    mask_path = None
-    return dict_out, dict_res
-
-
-def infer(
-    input_image,
-    ddim_steps,
-    seed,
-    scale,
-    removal_prompt,
-
-):
-    img_path = image_path
-    msk_path = mask_path
-    return predict(input_image,
-                   removal_prompt,
-                   ddim_steps,
-                   seed,
-                   scale,
-                   img_path,
-                   msk_path
-                   )
-
-def process_example(image_paths, mask_paths):
-    global image_path, mask_path
-    image = Image.open(image_paths).convert("RGB")
-    mask = Image.open(mask_paths).convert("L")
-    black_background = Image.new("RGB", image.size, (0, 0, 0))
-    masked_image = Image.composite(black_background, image, mask)
-
-    image_path = image_paths
-    mask_path = mask_paths
-    return masked_image
-custom_css = """
-
-.contain { max-width: 1200px !important; }
-
-.custom-image {
-    border: 2px dashed #7e22ce !important;
-    border-radius: 12px !important;
-    transition: all 0.3s ease !important;
-}
-.custom-image:hover {
-    border-color: #9333ea !important;
-    box-shadow: 0 4px 15px rgba(158, 109, 202, 0.2) !important;
-}
-
-.btn-primary {
-    background: linear-gradient(45deg, #7e22ce, #9333ea) !important;
-    border: none !important;
-    color: white !important;
-    border-radius: 8px !important;
-}
-#inline-examples {
-    border: 1px solid #e2e8f0 !important;
-    border-radius: 12px !important;
-    padding: 16px !important;
-    margin-top: 8px !important;
-}
-
-#inline-examples .thumbnail {
-    border-radius: 8px !important;
-    transition: transform 0.2s ease !important;
-}
-
-#inline-examples .thumbnail:hover {
-    transform: scale(1.05);
-    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
-}
-
-.example-title h3 {
-    margin: 0 0 12px 0 !important;
-    color: #475569 !important;
-    font-size: 1.1em !important;
-    display: flex !important;
-    align-items: center !important;
-}
-
-.example-title h3::before {
-    content: "📚";
-    margin-right: 8px;
-    font-size: 1.2em;
-}
-
-.row { align-items: stretch !important; }
-
-.panel { height: 100%; }
-"""
-
-with gr.Blocks(
-    css=custom_css,
-    theme=gr.themes.Soft(
-        primary_hue="purple",
-        secondary_hue="purple",
-        font=[gr.themes.GoogleFont('Inter'), 'sans-serif']
-    ),
-    title="Omnieraser"
-) as demo:
-    base_model_path = 'black-forest-labs/FLUX.1-dev'
-    lora_path = 'theSure/Omnieraser'
-    load_model(base_model_path=base_model_path, lora_path=lora_path)
-
-    ddim_steps = gr.Slider(visible=False, value=28)
-    scale = gr.Slider(visible=False, value=3.5)
-    seed = gr.Slider(visible=False, value=-1)
-    removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")
-
-    gr.Markdown("""
-    <div align="center">
-    <h1 style="font-size: 2.5em; margin-bottom: 0.5em;">🪄 Omnieraser</h1>
-    </div>
-    """)
-
-    with gr.Row(equal_height=False):
-        with gr.Column(scale=1, variant="panel"):
-            gr.Markdown("## 📥 Input Panel")
-
-            with gr.Group():
-                input_image = gr.Sketchpad(
-                    sources=["upload"],
-                    type="pil",
-                    label="Upload & Annotate",
-                    elem_id="custom-image",
-                    interactive=True
-                )
-            with gr.Row(variant="compact"):
-                run_button = gr.Button(
-                    "🚀 Start Processing",
-                    variant="primary",
-                    size="lg"
-                )
-            with gr.Group():
-                gr.Markdown("### ⚙️ Control Parameters")
-                seed = gr.Slider(
-                    label="Random Seed",
-                    minimum=-1,
-                    maximum=2147483647,
-                    value=1234,
-                    step=1,
-                    info="-1 for random generation"
-                )
-            with gr.Column(variant="panel"):
-                gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
-                example = gr.Examples(
-                    examples=image_examples,
-                    inputs=[
-                        gr.Image(label="Image", type="filepath", visible=False),
-                        gr.Image(label="Mask", type="filepath", visible=False)
-                    ],
-                    outputs=[input_image],
-                    fn=process_example,
-                    run_on_click=True,
-                    examples_per_page=10,
-                    label="Click any example to load",
-                    elem_id="inline-examples"
-                )
-
-        with gr.Column(scale=1, variant="panel"):
-            gr.Markdown("## 📤 Output Panel")
-            with gr.Tabs():
-                with gr.Tab("Final Result"):
-                    inpaint_result = gr.Gallery(
-                        label="Generated Image",
-                        columns=2,
-                        height=450,
-                        preview=True,
-                        object_fit="contain"
-                    )
-
-                with gr.Tab("Visualization Steps"):
-                    gallery = gr.Gallery(
-                        label="Workflow Steps",
-                        columns=2,
-                        height=450,
-                        object_fit="contain"
-                    )
-
-    run_button.click(
-        fn=infer,
-        inputs=[
-            input_image,
-            ddim_steps,
-            seed,
-            scale,
-            removal_prompt,
-        ],
-        outputs=[inpaint_result, gallery]
-    )
-
-if __name__ == '__main__':
-    demo.launch()
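The removed `load_model` above relies on a channel-expansion trick: the transformer's `x_embedder` linear layer is rebuilt with four times the input features, presumably so the packed latents of the control image and mask can be concatenated with the noisy latents, while the new weight columns are zero-initialized so the pretrained mapping is preserved at initialization. A minimal, self-contained sketch of that pattern (the helper name `expand_linear_inputs` is illustrative, not part of the commit):

    import torch

    def expand_linear_inputs(old: torch.nn.Linear, factor: int = 4) -> torch.nn.Linear:
        # Widen a pretrained Linear to accept `factor` times the input features.
        new = torch.nn.Linear(
            old.in_features * factor,
            old.out_features,
            bias=old.bias is not None,
            dtype=old.weight.dtype,
            device=old.weight.device,
        )
        with torch.no_grad():
            new.weight.zero_()  # added columns start at zero, so extra inputs are ignored at first
            new.weight[:, :old.in_features].copy_(old.weight)  # keep the pretrained mapping
            if old.bias is not None:
                new.bias.copy_(old.bias)
        return new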
+import gradio as gr
 import torch
+from diffusers import StableDiffusionInpaintPipeline
 import spaces
+from PIL import Image
 import numpy as np
+import random
+import os
+
+DESCRIPTION = "# Omnieraser\nRemove anything from any image using the [FLUX](https://huggingface.co/lllyasviel/flux) model."
+
+model_id = "lllyasviel/flux"
+lora_weights = "lllyasviel/flux-inpainting-internal"
+
+def load_pipeline():
+    pipe = StableDiffusionInpaintPipeline.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
+        variant="fp16"
     ).to("cuda")
+    pipe.load_lora_weights(lora_weights)
+    return pipe
+
+def inference(pipe, image, mask):
+    image = image.convert("RGB").resize((512, 512))
+    mask = mask.convert("RGB").resize((512, 512))
+
+    generator = torch.Generator("cuda").manual_seed(random.randint(0, 999999))
+    image = pipe(prompt="", image=image, mask_image=mask, guidance_scale=7.5, generator=generator).images[0]
+    return image
+
+def process_example(example, pipe):
+    image_path, mask_path = example
+    image = Image.open(image_path).convert("RGB")
+    mask = Image.open(mask_path).convert("RGB")
+    return inference(pipe, image, mask)
+
+def get_random_examples(dataset_dir="examples"):
+    image_dir = os.path.join(dataset_dir, "images")
+    mask_dir = os.path.join(dataset_dir, "masks")
+    files = os.listdir(image_dir)
+    random.shuffle(files)
+    examples = [
+        [os.path.join(image_dir, f), os.path.join(mask_dir, f)] for f in files if os.path.exists(os.path.join(mask_dir, f))
+    ]
+    return examples[:30]
+
+def build_ui(pipe):
+    with gr.Blocks(css="style.css") as demo:
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Input", type="pil")
+                mask_image = gr.Image(label="Mask", type="pil")
+
+                with gr.Row():
+                    submit = gr.Button("Run", elem_id="submit-button")
+
+            with gr.Column():
+                result_image = gr.Image(label="Output")
+
+        submit.click(
+            fn=lambda img, msk: inference(pipe, img, msk),
+            inputs=[input_image, mask_image],
+            outputs=result_image
+        )
+
+        gr.Markdown("## Examples")
+
+        image_examples = get_random_examples()
+        example = gr.Examples(
+            examples=image_examples,
+            inputs=[
+                gr.Image(label="Image", type="filepath", visible=False),
+                gr.Image(label="Mask", type="filepath", visible=False)
+            ],
+            outputs=[input_image],
+            # gr.Examples calls fn with one argument per input component
+            fn=lambda image_path, mask_path: process_example((image_path, mask_path), pipe),
+            run_on_click=True,
+            label="Click any example to load",
+            elem_id="inline-examples"
+        )
+
+        gr.Markdown("Try drawing over objects you want to remove.")
+
+    return demo
+
+# Load the pipeline and launch the UI
+if __name__ == "__main__":
+    pipe = load_pipeline()
+    demo = build_ui(pipe)
+    demo.launch()
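For reference, a hypothetical smoke test of the new entry points outside the Gradio UI; it assumes the committed `app.py` is on the import path and that `examples/images` and `examples/masks` contain files with matching names (`demo.png` below is a placeholder, not a file from the repo). The `if __name__ == "__main__"` guard is what makes the module importable without launching the app:

    from PIL import Image

    from app import load_pipeline, inference

    pipe = load_pipeline()  # downloads the base model and LoRA weights on first use

    image = Image.open("examples/images/demo.png")
    mask = Image.open("examples/masks/demo.png")  # white pixels mark the region to repaint

    result = inference(pipe, image, mask)  # resizes both to 512x512 and inpaints with an empty prompt
    result.save("output.png")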