theSure commited on
Commit
373d8e0
Β·
verified Β·
1 Parent(s): a49cc2f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +351 -0
app.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import shutil
4
+ import uuid
5
+ import torch
6
+ import random
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+
11
+ from PIL import Image, ImageCms
12
+ import torch
13
+ from diffusers import FluxTransformer2DModel
14
+ from diffusers.utils import load_image
15
+ from pipeline_flux_control_removal import FluxControlRemovalPipeline
16
+
17
+ torch.set_grad_enabled(False)
18
+ os.environ['GRADIO_TEMP_DIR'] = './tmp'
19
+
20
+ image_path = mask_path = None
21
+ image_examples = [...]
22
+ image_path = mask_path =None
23
+ image_examples = [
24
+ [
25
+ "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
26
+ "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
27
+ ],
28
+ [
29
+ "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
30
+ "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
31
+ ],
32
+ [
33
+ "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
34
+ "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
35
+ ],
36
+ [
37
+ "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
38
+ "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
39
+ ],
40
+ [
41
+ "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
42
+ "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
43
+ ],
44
+ [
45
+ "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
46
+ "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
47
+ ],
48
+ [
49
+ "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
50
+ "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
51
+ ]
52
+
53
+ ]
54
+
55
+ def load_model(base_model_path, lora_path):
56
+ global pipe
57
+ transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
58
+ gr.Info(str(f"Model loading: {int((40 / 100) * 100)}%"))
59
+ # enable image inputs
60
+ with torch.no_grad():
61
+ initial_input_channels = transformer.config.in_channels
62
+ new_linear = torch.nn.Linear(
63
+ transformer.x_embedder.in_features*4,
64
+ transformer.x_embedder.out_features,
65
+ bias=transformer.x_embedder.bias is not None,
66
+ dtype=transformer.dtype,
67
+ device=transformer.device,
68
+ )
69
+ new_linear.weight.zero_()
70
+ new_linear.weight[:, :initial_input_channels].copy_(transformer.x_embedder.weight)
71
+
72
+ if transformer.x_embedder.bias is not None:
73
+ new_linear.bias.copy_(transformer.x_embedder.bias)
74
+
75
+ transformer.x_embedder = new_linear
76
+ transformer.register_to_config(in_channels=initial_input_channels*4)
77
+
78
+ pipe = FluxControlRemovalPipeline.from_pretrained(
79
+ base_model_path,
80
+ transformer=transformer,
81
+ torch_dtype=torch.bfloat16
82
+ ).to("cuda")
83
+ pipe.transformer.to(torch.bfloat16)
84
+ gr.Info(str(f"Model loading: {int((80 / 100) * 100)}%"))
85
+ gr.Info(str(f"Inject LoRA: {lora_path}"))
86
+ pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
87
+ gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
88
+
89
+ def set_seed(seed):
90
+ torch.manual_seed(seed)
91
+ torch.cuda.manual_seed(seed)
92
+ torch.cuda.manual_seed_all(seed)
93
+ np.random.seed(seed)
94
+ random.seed(seed)
95
+
96
+
97
+ def predict(
98
+ input_image,
99
+ prompt,
100
+ ddim_steps,
101
+ seed,
102
+ scale,
103
+ image_paths,
104
+ mask_paths
105
+ ):
106
+ global image_path, mask_path
107
+ gr.Info(str(f"Set seed = {seed}"))
108
+ if image_paths is not None:
109
+ input_image["image"] = load_image(image_paths).convert("RGB")
110
+ input_image["mask"] = load_image(mask_paths).convert("RGB")
111
+
112
+ size1, size2 = input_image["image"].convert("RGB").size
113
+
114
+ icc_profile = input_image["image"].info.get('icc_profile')
115
+ if icc_profile:
116
+ gr.Info(str(f"Image detected to contain ICC profile, converting color space to sRGB..."))
117
+ srgb_profile = ImageCms.createProfile("sRGB")
118
+ io_handle = io.BytesIO(icc_profile)
119
+ src_profile = ImageCms.ImageCmsProfile(io_handle)
120
+ input_image["image"] = ImageCms.profileToProfile(input_image["image"], src_profile, srgb_profile)
121
+ input_image["image"].info.pop('icc_profile', None)
122
+
123
+ if size1 < size2:
124
+ input_image["image"] = input_image["image"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
125
+ else:
126
+ input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))
127
+
128
+ img = np.array(input_image["image"].convert("RGB"))
129
+
130
+ W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
131
+ H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
132
+
133
+ input_image["image"] = input_image["image"].resize((H, W))
134
+ input_image["mask"] = input_image["mask"].resize((H, W))
135
+
136
+ if seed == -1:
137
+ seed = random.randint(1, 2147483647)
138
+ set_seed(random.randint(1, 2147483647))
139
+ else:
140
+ set_seed(seed)
141
+
142
+
143
+ result = pipe(
144
+ prompt=prompt,
145
+ control_image=input_image["image"].convert("RGB"),
146
+ control_mask=input_image["mask"].convert("RGB"),
147
+ width=H,
148
+ height=W,
149
+ num_inference_steps=ddim_steps,
150
+ generator=torch.Generator("cuda").manual_seed(seed),
151
+ guidance_scale=scale,
152
+ max_sequence_length=512,
153
+ ).images[0]
154
+
155
+ mask_np = np.array(input_image["mask"].convert("RGB"))
156
+ red = np.array(input_image["image"]).astype("float") * 1
157
+ red[:, :, 0] = 180.0
158
+ red[:, :, 2] = 0
159
+ red[:, :, 1] = 0
160
+ result_m = np.array(input_image["image"])
161
+ result_m = Image.fromarray(
162
+ (
163
+ result_m.astype("float") * (1 - mask_np.astype("float") / 512.0) + mask_np.astype("float") / 512.0 * red
164
+ ).astype("uint8")
165
+ )
166
+
167
+ dict_res = [input_image["image"], input_image["mask"], result_m, result]
168
+
169
+ dict_out = [result]
170
+ image_path = None
171
+ mask_path = None
172
+ return dict_out, dict_res
173
+
174
+
175
+ def infer(
176
+ input_image,
177
+ ddim_steps,
178
+ seed,
179
+ scale,
180
+ removal_prompt,
181
+ ):
182
+ img_path = image_path
183
+ msk_path = mask_path
184
+ return predict(input_image,
185
+ removal_prompt,
186
+ ddim_steps,
187
+ seed,
188
+ scale,
189
+ img_path,
190
+ msk_path
191
+ )
192
+
193
+ def process_example(image_paths, mask_paths):
194
+ global image_path, mask_path
195
+ image = Image.open(image_paths).convert("RGB")
196
+ mask = Image.open(mask_paths).convert("L")
197
+ black_background = Image.new("RGB", image.size, (0, 0, 0))
198
+ masked_image = Image.composite(black_background, image, mask)
199
+
200
+ image_path = image_paths
201
+ mask_path = mask_paths
202
+ return masked_image
203
+ custom_css = """
204
+ .contain { max-width: 1200px !important; }
205
+ .custom-image {
206
+ border: 2px dashed #7e22ce !important;
207
+ border-radius: 12px !important;
208
+ transition: all 0.3s ease !important;
209
+ }
210
+ .custom-image:hover {
211
+ border-color: #9333ea !important;
212
+ box-shadow: 0 4px 15px rgba(158, 109, 202, 0.2) !important;
213
+ }
214
+ .btn-primary {
215
+ background: linear-gradient(45deg, #7e22ce, #9333ea) !important;
216
+ border: none !important;
217
+ color: white !important;
218
+ border-radius: 8px !important;
219
+ }
220
+ #inline-examples {
221
+ border: 1px solid #e2e8f0 !important;
222
+ border-radius: 12px !important;
223
+ padding: 16px !important;
224
+ margin-top: 8px !important;
225
+ }
226
+ #inline-examples .thumbnail {
227
+ border-radius: 8px !important;
228
+ transition: transform 0.2s ease !important;
229
+ }
230
+ #inline-examples .thumbnail:hover {
231
+ transform: scale(1.05);
232
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
233
+ }
234
+ .example-title h3 {
235
+ margin: 0 0 12px 0 !important;
236
+ color: #475569 !important;
237
+ font-size: 1.1em !important;
238
+ display: flex !important;
239
+ align-items: center !important;
240
+ }
241
+ .example-title h3::before {
242
+ content: "πŸ“š";
243
+ margin-right: 8px;
244
+ font-size: 1.2em;
245
+ }
246
+ """
247
+
248
+ with gr.Blocks(
249
+ css=custom_css,
250
+ theme=gr.themes.Soft(
251
+ primary_hue="purple",
252
+ secondary_hue="purple",
253
+ font=[gr.themes.GoogleFont('Inter'), 'sans-serif']
254
+ ),
255
+ title="Omnieraser"
256
+ ) as demo:
257
+ base_model_path = "black-forest-labs/FLUX.1-dev"
258
+ lora_path = 'theSure/Omnieraser'
259
+ load_model(base_model_path=base_model_path, lora_path=lora_path)
260
+
261
+ ddim_steps = gr.Slider(visible=False, value=28)
262
+ scale = gr.Slider(visible=False, value=3.5)
263
+ seed = gr.Slider(visible=False, value=-1)
264
+ removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")
265
+
266
+ gr.Markdown("""
267
+ <div align="center">
268
+ <h1 style="font-size: 2.5em; margin-bottom: 0.5em;">πŸͺ„ Omnieraser</h1>
269
+ </div>
270
+ """)
271
+
272
+ with gr.Row(equal_height=True):
273
+ with gr.Column(scale=1, variant="panel"):
274
+ gr.Markdown("## πŸ“₯ Input Panel")
275
+
276
+ with gr.Group(border=True):
277
+ input_image = gr.Image(
278
+ source="upload",
279
+ tool="sketch",
280
+ type="pil",
281
+ label="Upload & Annotate",
282
+ height=400,
283
+ elem_id="custom-image",
284
+ interactive=True
285
+ )
286
+ with gr.Row(variant="compact"):
287
+ run_button = gr.Button(
288
+ "πŸš€ Start Processing",
289
+ variant="primary",
290
+ size="lg"
291
+ )
292
+ with gr.Group(border=True):
293
+ gr.Markdown("### βš™οΈ Control Parameters")
294
+ seed = gr.Slider(
295
+ label="Random Seed",
296
+ minimum=-1,
297
+ maximum=2147483647,
298
+ value=1234,
299
+ step=1,
300
+ info="-1 for random generation"
301
+ )
302
+ with gr.Column(variant="panel"):
303
+ gr.Markdown("### πŸ–ΌοΈ Example Gallery", elem_classes=["example-title"])
304
+ example = gr.Examples(
305
+ examples=image_examples,
306
+ inputs=[
307
+ gr.Image(label="Image", type="filepath",visible=False),
308
+ gr.Image(label="Mask", type="filepath",visible=False)
309
+ ],
310
+ outputs=[input_image],
311
+ fn=process_example,
312
+ run_on_click=True,
313
+ examples_per_page=10,
314
+ label="Click any example to load",
315
+ elem_id="inline-examples"
316
+ )
317
+
318
+ with gr.Column(scale=1, variant="panel"):
319
+ gr.Markdown("## πŸ“€ Output Panel")
320
+ with gr.Tabs():
321
+ with gr.Tab("Final Result"):
322
+ inpaint_result = gr.Gallery(
323
+ label="Generated Image",
324
+ columns=2,
325
+ height=450,
326
+ preview=True,
327
+ object_fit="contain"
328
+ )
329
+
330
+ with gr.Tab("Visualization Steps"):
331
+ gallery = gr.Gallery(
332
+ label="Workflow Steps",
333
+ columns=2,
334
+ height=450,
335
+ object_fit="contain"
336
+ )
337
+
338
+ run_button.click(
339
+ fn=infer,
340
+ inputs=[
341
+ input_image,
342
+ ddim_steps,
343
+ seed,
344
+ scale,
345
+ removal_prompt,
346
+ ],
347
+ outputs=[inpaint_result, gallery]
348
+ )
349
+
350
+
351
+ demo.launch()