fantos committed on
Commit b11a5dd · verified · 1 Parent(s): e28a866

Delete src/app.py

Files changed (1)
  1. src/app.py +0 -300
src/app.py DELETED
@@ -1,300 +0,0 @@
- import tempfile
- import time
- from collections.abc import Sequence
- from typing import Any, cast
-
- import gradio as gr
- import numpy as np
- import pillow_heif
- import spaces
- import torch
- from gradio_image_annotation import image_annotator
- from gradio_imageslider import ImageSlider
- from PIL import Image
- from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
- from refiners.fluxion.utils import no_grad
- from refiners.solutions import BoxSegmenter
- from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
-
- BoundingBox = tuple[int, int, int, int]
-
- pillow_heif.register_heif_opener()
- pillow_heif.register_avif_opener()
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # weird dance because ZeroGPU
- segmenter = BoxSegmenter(device="cpu")
- segmenter.device = device
- segmenter.model = segmenter.model.to(device=segmenter.device)
-
- gd_model_path = "IDEA-Research/grounding-dino-base"
- gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
- gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
- gd_model = gd_model.to(device=device) # type: ignore
- assert isinstance(gd_model, GroundingDinoForObjectDetection)
-
-
- def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
-     if not bboxes:
-         return None
-     for bbox in bboxes:
-         assert len(bbox) == 4
-         assert all(isinstance(x, int) for x in bbox)
-     return (
-         min(bbox[0] for bbox in bboxes),
-         min(bbox[1] for bbox in bboxes),
-         max(bbox[2] for bbox in bboxes),
-         max(bbox[3] for bbox in bboxes),
-     )
-
-
- def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
-     x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
-     return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)
-
-
- def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
-     assert isinstance(gd_processor, GroundingDinoProcessor)
-
-     # Grounding Dino expects a dot after each category.
-     inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
-
-     with no_grad():
-         outputs = gd_model(**inputs)
-     width, height = img.size
-     results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
-         outputs,
-         inputs["input_ids"],
-         target_sizes=[(height, width)],
-     )[0]
-     assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
-
-     bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
-     return bbox_union(bboxes.numpy().tolist())
-
-
- def apply_mask(
-     img: Image.Image,
-     mask_img: Image.Image,
-     defringe: bool = True,
- ) -> Image.Image:
-     assert img.size == mask_img.size
-     img = img.convert("RGB")
-     mask_img = mask_img.convert("L")
-
-     if defringe:
-         # Mitigate edge halo effects via color decontamination
-         rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
-         foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
-         img = Image.fromarray((foreground * 255).astype("uint8"))
-
-     result = Image.new("RGBA", img.size)
-     result.paste(img, (0, 0), mask_img)
-     return result
-
-
- @spaces.GPU
- def _gpu_process(
-     img: Image.Image,
-     prompt: str | BoundingBox | None,
- ) -> tuple[Image.Image, BoundingBox | None, list[str]]:
-     # Because of ZeroGPU shenanigans, we need a *single* function with the
-     # `spaces.GPU` decorator that *does not* contain postprocessing.
-
-     time_log: list[str] = []
-
-     if isinstance(prompt, str):
-         t0 = time.time()
-         bbox = gd_detect(img, prompt)
-         time_log.append(f"detect: {time.time() - t0}")
-         if not bbox:
-             print(time_log[0])
-             raise gr.Error("No object detected")
-     else:
-         bbox = prompt
-
-     t0 = time.time()
-     mask = segmenter(img, bbox)
-     time_log.append(f"segment: {time.time() - t0}")
-
-     return mask, bbox, time_log
-
-
- def _process(
-     img: Image.Image,
-     prompt: str | BoundingBox | None,
- ) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
-     # enforce max dimensions for pymatting performance reasons
-     if img.width > 2048 or img.height > 2048:
-         orig_res = max(img.width, img.height)
-         img.thumbnail((2048, 2048))
-         if isinstance(prompt, tuple):
-             x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in prompt)
-             prompt = (x0, y0, x1, y1)
-
-     mask, bbox, time_log = _gpu_process(img, prompt)
-
-     t0 = time.time()
-     masked_alpha = apply_mask(img, mask, defringe=True)
-     time_log.append(f"crop: {time.time() - t0}")
-     print(", ".join(time_log))
-
-     masked_rgb = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
-
-     thresholded = mask.point(lambda p: 255 if p > 10 else 0)
-     bbox = thresholded.getbbox()
-     to_dl = masked_alpha.crop(bbox)
-
-     temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
-     to_dl.save(temp, format="PNG")
-     temp.close()
-
-     return (img, masked_rgb), gr.DownloadButton(value=temp.name, interactive=True)
-
-
- def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
-     assert isinstance(img := prompts["image"], Image.Image)
-     assert isinstance(boxes := prompts["boxes"], list)
-     if len(boxes) == 1:
-         assert isinstance(box := boxes[0], dict)
-         bbox = tuple(box[k] for k in ["xmin", "ymin", "xmax", "ymax"])
-     else:
-         assert len(boxes) == 0
-         bbox = None
-     return _process(img, bbox)
-
-
- def on_change_bbox(prompts: dict[str, Any] | None):
-     return gr.update(interactive=prompts is not None)
-
-
- def process_prompt(img: Image.Image, prompt: str) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
-     return _process(img, prompt)
-
-
- def on_change_prompt(img: Image.Image | None, prompt: str | None):
-     return gr.update(interactive=bool(img and prompt))
-
-
- css = """
- footer {
-     visibility: hidden;
- }
- """
-
-
- with gr.Blocks(css=css) as demo:
-
-     with gr.Tab("By prompt", id="tab_prompt"):
-         with gr.Row():
-             with gr.Column():
-                 iimg = gr.Image(type="pil", label="Input")
-                 prompt = gr.Textbox(label="What should we cut?")
-                 btn = gr.ClearButton(value="Cut Out Object", interactive=False)
-             with gr.Column():
-                 oimg = ImageSlider(label="Before / After", show_download_button=False, interactive=False)
-                 dlbt = gr.DownloadButton("Download Cutout", interactive=False)
-
-         btn.add(oimg)
-
-         for inp in [iimg, prompt]:
-             inp.change(
-                 fn=on_change_prompt,
-                 inputs=[iimg, prompt],
-                 outputs=[btn],
-             )
-         btn.click(
-             fn=process_prompt,
-             inputs=[iimg, prompt],
-             outputs=[oimg, dlbt],
-             api_name=False,
-         )
-
-         examples = [
-             [
-                 "examples/text.jpg",
-                 "text",
-             ],
-             [
-                 "examples/potted-plant.jpg",
-                 "potted plant",
-             ],
-             [
-                 "examples/chair.jpg",
-                 "chair",
-             ],
-             [
-                 "examples/black-lamp.jpg",
-                 "black lamp",
-             ],
-         ]
-
-         ex = gr.Examples(
-             examples=examples,
-             inputs=[iimg, prompt],
-             outputs=[oimg, dlbt],
-             fn=process_prompt,
-             cache_examples=True,
-         )
-
-     with gr.Tab("By bounding box", id="tab_bb"):
-         with gr.Row():
-             with gr.Column():
-                 annotator = image_annotator(
-                     image_type="pil",
-                     disable_edit_boxes=True,
-                     show_download_button=False,
-                     show_share_button=False,
-                     single_box=True,
-                     label="Input",
-                 )
-                 btn = gr.ClearButton(value="Cut Out Object", interactive=False)
-             with gr.Column():
-                 oimg = ImageSlider(label="Before / After", show_download_button=False)
-                 dlbt = gr.DownloadButton("Download Cutout", interactive=False)
-
-         btn.add(oimg)
-
-         annotator.change(
-             fn=on_change_bbox,
-             inputs=[annotator],
-             outputs=[btn],
-         )
-         btn.click(
-             fn=process_bbox,
-             inputs=[annotator],
-             outputs=[oimg, dlbt],
-             api_name=False,
-         )
-
-         examples = [
-             {
-                 "image": "examples/text.jpg",
-                 "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
-             },
-             {
-                 "image": "examples/potted-plant.jpg",
-                 "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
-             },
-             {
-                 "image": "examples/chair.jpg",
-                 "boxes": [{"xmin": 98, "ymin": 330, "xmax": 973, "ymax": 1468}],
-             },
-             {
-                 "image": "examples/black-lamp.jpg",
-                 "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}],
-             },
-         ]
-
-         ex = gr.Examples(
-             examples=examples,
-             inputs=[annotator],
-             outputs=[oimg, dlbt],
-             fn=process_bbox,
-             cache_examples=True,
-         )
-
-
- demo.queue(max_size=30, api_open=False)
- demo.launch(show_api=False)