fantos commited on
Commit
5624229
·
verified ·
1 Parent(s): b11a5dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +300 -0
app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import time
3
+ from collections.abc import Sequence
4
+ from typing import Any, cast
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import pillow_heif
9
+ import spaces
10
+ import torch
11
+ from gradio_image_annotation import image_annotator
12
+ from gradio_imageslider import ImageSlider
13
+ from PIL import Image
14
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
15
+ from refiners.fluxion.utils import no_grad
16
+ from refiners.solutions import BoxSegmenter
17
+ from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
18
+
19
+ BoundingBox = tuple[int, int, int, int]
20
+
21
+ pillow_heif.register_heif_opener()
22
+ pillow_heif.register_avif_opener()
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ # weird dance because ZeroGPU
27
+ segmenter = BoxSegmenter(device="cpu")
28
+ segmenter.device = device
29
+ segmenter.model = segmenter.model.to(device=segmenter.device)
30
+
31
+ gd_model_path = "IDEA-Research/grounding-dino-base"
32
+ gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
33
+ gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
34
+ gd_model = gd_model.to(device=device) # type: ignore
35
+ assert isinstance(gd_model, GroundingDinoForObjectDetection)
36
+
37
+
38
+ def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
39
+ if not bboxes:
40
+ return None
41
+ for bbox in bboxes:
42
+ assert len(bbox) == 4
43
+ assert all(isinstance(x, int) for x in bbox)
44
+ return (
45
+ min(bbox[0] for bbox in bboxes),
46
+ min(bbox[1] for bbox in bboxes),
47
+ max(bbox[2] for bbox in bboxes),
48
+ max(bbox[3] for bbox in bboxes),
49
+ )
50
+
51
+
52
+ def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
53
+ x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
54
+ return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)
55
+
56
+
57
+ def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
58
+ assert isinstance(gd_processor, GroundingDinoProcessor)
59
+
60
+ # Grounding Dino expects a dot after each category.
61
+ inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
62
+
63
+ with no_grad():
64
+ outputs = gd_model(**inputs)
65
+ width, height = img.size
66
+ results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
67
+ outputs,
68
+ inputs["input_ids"],
69
+ target_sizes=[(height, width)],
70
+ )[0]
71
+ assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
72
+
73
+ bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
74
+ return bbox_union(bboxes.numpy().tolist())
75
+
76
+
77
+ def apply_mask(
78
+ img: Image.Image,
79
+ mask_img: Image.Image,
80
+ defringe: bool = True,
81
+ ) -> Image.Image:
82
+ assert img.size == mask_img.size
83
+ img = img.convert("RGB")
84
+ mask_img = mask_img.convert("L")
85
+
86
+ if defringe:
87
+ # Mitigate edge halo effects via color decontamination
88
+ rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
89
+ foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
90
+ img = Image.fromarray((foreground * 255).astype("uint8"))
91
+
92
+ result = Image.new("RGBA", img.size)
93
+ result.paste(img, (0, 0), mask_img)
94
+ return result
95
+
96
+
97
+ @spaces.GPU
98
+ def _gpu_process(
99
+ img: Image.Image,
100
+ prompt: str | BoundingBox | None,
101
+ ) -> tuple[Image.Image, BoundingBox | None, list[str]]:
102
+ # Because of ZeroGPU shenanigans, we need a *single* function with the
103
+ # `spaces.GPU` decorator that *does not* contain postprocessing.
104
+
105
+ time_log: list[str] = []
106
+
107
+ if isinstance(prompt, str):
108
+ t0 = time.time()
109
+ bbox = gd_detect(img, prompt)
110
+ time_log.append(f"detect: {time.time() - t0}")
111
+ if not bbox:
112
+ print(time_log[0])
113
+ raise gr.Error("No object detected")
114
+ else:
115
+ bbox = prompt
116
+
117
+ t0 = time.time()
118
+ mask = segmenter(img, bbox)
119
+ time_log.append(f"segment: {time.time() - t0}")
120
+
121
+ return mask, bbox, time_log
122
+
123
+
124
+ def _process(
125
+ img: Image.Image,
126
+ prompt: str | BoundingBox | None,
127
+ ) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
128
+ # enforce max dimensions for pymatting performance reasons
129
+ if img.width > 2048 or img.height > 2048:
130
+ orig_res = max(img.width, img.height)
131
+ img.thumbnail((2048, 2048))
132
+ if isinstance(prompt, tuple):
133
+ x0, y0, x1, y2 = (int(x * 2048 / orig_res) for x in prompt)
134
+ prompt = (x0, y0, x1, y2)
135
+
136
+ mask, bbox, time_log = _gpu_process(img, prompt)
137
+
138
+ t0 = time.time()
139
+ masked_alpha = apply_mask(img, mask, defringe=True)
140
+ time_log.append(f"crop: {time.time() - t0}")
141
+ print(", ".join(time_log))
142
+
143
+ masked_rgb = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
144
+
145
+ thresholded = mask.point(lambda p: 255 if p > 10 else 0)
146
+ bbox = thresholded.getbbox()
147
+ to_dl = masked_alpha.crop(bbox)
148
+
149
+ temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
150
+ to_dl.save(temp, format="PNG")
151
+ temp.close()
152
+
153
+ return (img, masked_rgb), gr.DownloadButton(value=temp.name, interactive=True)
154
+
155
+
156
+ def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
157
+ assert isinstance(img := prompts["image"], Image.Image)
158
+ assert isinstance(boxes := prompts["boxes"], list)
159
+ if len(boxes) == 1:
160
+ assert isinstance(box := boxes[0], dict)
161
+ bbox = tuple(box[k] for k in ["xmin", "ymin", "xmax", "ymax"])
162
+ else:
163
+ assert len(boxes) == 0
164
+ bbox = None
165
+ return _process(img, bbox)
166
+
167
+
168
+ def on_change_bbox(prompts: dict[str, Any] | None):
169
+ return gr.update(interactive=prompts is not None)
170
+
171
+
172
+ def process_prompt(img: Image.Image, prompt: str) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
173
+ return _process(img, prompt)
174
+
175
+
176
+ def on_change_prompt(img: Image.Image | None, prompt: str | None):
177
+ return gr.update(interactive=bool(img and prompt))
178
+
179
+
180
+ css = """
181
+ footer {
182
+ visibility: hidden;
183
+ }
184
+ """
185
+
186
+
187
+ with gr.Blocks(css=css) as demo:
188
+
189
+ with gr.Tab("By prompt", id="tab_prompt"):
190
+ with gr.Row():
191
+ with gr.Column():
192
+ iimg = gr.Image(type="pil", label="Input")
193
+ prompt = gr.Textbox(label="What should we cut?")
194
+ btn = gr.Button("Cut Out Object", interactive=False) # 수정됨: ClearButton에서 Button으로 변경
195
+ with gr.Column():
196
+ oimg = ImageSlider(label="Before / After", show_download_button=False, interactive=False)
197
+ dlbt = gr.DownloadButton("Download Cutout", interactive=False)
198
+
199
+ btn.add(oimg)
200
+
201
+ for inp in [iimg, prompt]:
202
+ inp.change(
203
+ fn=on_change_prompt,
204
+ inputs=[iimg, prompt],
205
+ outputs=[btn],
206
+ )
207
+ btn.click(
208
+ fn=process_prompt,
209
+ inputs=[iimg, prompt],
210
+ outputs=[oimg, dlbt],
211
+ api_name=False,
212
+ )
213
+
214
+ examples = [
215
+ [
216
+ "examples/text.jpg",
217
+ "text",
218
+ ],
219
+ [
220
+ "examples/potted-plant.jpg",
221
+ "potted plant",
222
+ ],
223
+ [
224
+ "examples/chair.jpg",
225
+ "chair",
226
+ ],
227
+ [
228
+ "examples/black-lamp.jpg",
229
+ "black lamp",
230
+ ],
231
+ ]
232
+
233
+ ex = gr.Examples(
234
+ examples=examples,
235
+ inputs=[iimg, prompt],
236
+ outputs=[oimg, dlbt],
237
+ fn=process_prompt,
238
+ cache_examples=True,
239
+ )
240
+
241
+ with gr.Tab("By bounding box", id="tab_bb"):
242
+ with gr.Row():
243
+ with gr.Column():
244
+ annotator = image_annotator(
245
+ image_type="pil",
246
+ disable_edit_boxes=True,
247
+ show_download_button=False,
248
+ show_share_button=False,
249
+ single_box=True,
250
+ label="Input",
251
+ )
252
+ btn = gr.Button("Cut Out Object", interactive=False) # 수정됨: ClearButton에서 Button으로 변경
253
+ with gr.Column():
254
+ oimg = ImageSlider(label="Before / After", show_download_button=False)
255
+ dlbt = gr.DownloadButton("Download Cutout", interactive=False)
256
+
257
+ btn.add(oimg)
258
+
259
+ annotator.change(
260
+ fn=on_change_bbox,
261
+ inputs=[annotator],
262
+ outputs=[btn],
263
+ )
264
+ btn.click(
265
+ fn=process_bbox,
266
+ inputs=[annotator],
267
+ outputs=[oimg, dlbt],
268
+ api_name=False,
269
+ )
270
+
271
+ examples = [
272
+ {
273
+ "image": "examples/text.jpg",
274
+ "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
275
+ },
276
+ {
277
+ "image": "examples/potted-plant.jpg",
278
+ "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
279
+ },
280
+ {
281
+ "image": "examples/chair.jpg",
282
+ "boxes": [{"xmin": 98, "ymin": 330, "xmax": 973, "ymax": 1468}],
283
+ },
284
+ {
285
+ "image": "examples/black-lamp.jpg",
286
+ "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}],
287
+ },
288
+ ]
289
+
290
+ ex = gr.Examples(
291
+ examples=examples,
292
+ inputs=[annotator],
293
+ outputs=[oimg, dlbt],
294
+ fn=process_bbox,
295
+ cache_examples=True,
296
+ )
297
+
298
+
299
+ demo.queue(max_size=30, api_open=False)
300
+ demo.launch(show_api=False)