Commit e53ae56 (0 parents)
Committed by imjunaidafzal and xvjiarui

Duplicate from xvjiarui/ODISE

Co-authored-by: Jiarui Xu <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: ODISE
+ emoji: 🤗
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 3.29.0
+ app_file: app.py
+ pinned: true
+ duplicated_from: xvjiarui/ODISE
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,375 @@
+ # ------------------------------------------------------------------------------
+ # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License.
+ # To view a copy of this license, visit
+ # https://github.com/NVlabs/ODISE/blob/main/LICENSE
+ #
+ # Written by Jiarui Xu
+ # ------------------------------------------------------------------------------
+
+ import os
+ os.system("pip install git+https://github.com/NVlabs/ODISE.git")
+ os.system("pip freeze")
+
+ import itertools
+ import json
+ from contextlib import ExitStack
+ import gradio as gr
+ import numpy as np
+ import matplotlib.colors as mplc
+ import torch
+ from mask2former.data.datasets.register_ade20k_panoptic import ADE20K_150_CATEGORIES
+ from PIL import Image
+ from torch.cuda.amp import autocast
+
+ from detectron2.config import instantiate
+ from detectron2.data import MetadataCatalog
+ from detectron2.data import detection_utils as utils
+ from detectron2.data import transforms as T
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+ from detectron2.evaluation import inference_context
+ from detectron2.utils.env import seed_all_rng
+ from detectron2.utils.logger import setup_logger
+ from detectron2.utils.visualizer import ColorMode, Visualizer as _Visualizer, random_color
+
+ from odise import model_zoo
+ from odise.checkpoint import ODISECheckpointer
+ from odise.config import instantiate_odise
+ from odise.data import get_openseg_labels
+ from odise.modeling.wrapper import OpenPanopticInference
+
+ setup_logger()
+ logger = setup_logger(name="odise")
+
+ COCO_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 1
+ ]
+ COCO_THING_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 1]
+ COCO_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 0
+ ]
+ COCO_STUFF_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 0]
+
+ ADE_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 1
+ ]
+ ADE_THING_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 1]
+ ADE_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 0
+ ]
+ ADE_STUFF_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 0]
+
+ LVIS_CLASSES = get_openseg_labels("lvis_1203", True)
+ # use beautiful coco colors
+ LVIS_COLORS = list(
+     itertools.islice(itertools.cycle([c["color"] for c in COCO_CATEGORIES]), len(LVIS_CLASSES))
+ )
+
+ class Visualizer(_Visualizer):
+
+     def draw_text(
+         self,
+         text,
+         position,
+         *,
+         font_size=None,
+         color="g",
+         horizontal_alignment="center",
+         rotation=0,
+     ):
+         """
+         Args:
+             text (str): class label
+             position (tuple): a tuple of the x and y coordinates to place text on image.
+             font_size (int, optional): font of the text. If not provided, a font size
+                 proportional to the image width is calculated and used.
+             color: color of the text. Refer to `matplotlib.colors` for full list
+                 of formats that are accepted.
+             horizontal_alignment (str): see `matplotlib.text.Text`
+             rotation: rotation angle in degrees CCW
+
+         Returns:
+             output (VisImage): image object with text drawn.
+         """
+         if not font_size:
+             font_size = self._default_font_size
+
+         # since the text background is dark, we don't want the text to be dark
+         color = np.clip(color, 0, 1).tolist()
+         color = np.maximum(list(mplc.to_rgb(color)), 0.2)
+         color[np.argmax(color)] = max(0.8, np.max(color))
+
+         x, y = position
+         self.output.ax.text(
+             x,
+             y,
+             text,
+             size=font_size * self.output.scale,
+             family="sans-serif",
+             bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
+             verticalalignment="top",
+             horizontalalignment=horizontal_alignment,
+             color=color,
+             zorder=10,
+             rotation=rotation,
+         )
+         return self.output
+
+ class VisualizationDemo(object):
+     def __init__(self, model, metadata, aug, instance_mode=ColorMode.IMAGE):
+         """
+         Args:
+             model (nn.Module):
+             metadata (MetadataCatalog): image metadata.
+             instance_mode (ColorMode):
+             parallel (bool): whether to run the model in different processes from visualization.
+                 Useful since the visualization logic can be slow.
+         """
+         self.model = model
+         self.metadata = metadata
+         self.aug = aug
+         self.cpu_device = torch.device("cpu")
+         self.instance_mode = instance_mode
+
+     def predict(self, original_image):
+         """
+         Args:
+             original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+         Returns:
+             predictions (dict):
+                 the output of the model for one image only.
+                 See :doc:`/tutorials/models` for details about the format.
+         """
+         height, width = original_image.shape[:2]
+         aug_input = T.AugInput(original_image, sem_seg=None)
+         self.aug(aug_input)
+         image = aug_input.image
+         image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+         inputs = {"image": image, "height": height, "width": width}
+         logger.info("forwarding")
+         with autocast():
+             predictions = self.model([inputs])[0]
+         logger.info("done")
+         return predictions
+
+     def run_on_image(self, image):
+         """
+         Args:
+             image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+                 This is the format used by OpenCV.
+         Returns:
+             predictions (dict): the output of the model.
+             vis_output (VisImage): the visualized image output.
+         """
+         vis_output = None
+         predictions = self.predict(image)
+         visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
+         if "panoptic_seg" in predictions:
+             panoptic_seg, segments_info = predictions["panoptic_seg"]
+             vis_output = visualizer.draw_panoptic_seg(
+                 panoptic_seg.to(self.cpu_device), segments_info
+             )
+         else:
+             if "sem_seg" in predictions:
+                 vis_output = visualizer.draw_sem_seg(
+                     predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
+                 )
+             if "instances" in predictions:
+                 instances = predictions["instances"].to(self.cpu_device)
+                 vis_output = visualizer.draw_instance_predictions(predictions=instances)
+
+         return predictions, vis_output
+
+
+ cfg = model_zoo.get_config("Panoptic/odise_label_coco_50e.py", trained=True)
+
+ cfg.model.overlap_threshold = 0
+ cfg.train.device = "cuda" if torch.cuda.is_available() else "cpu"
+ seed_all_rng(42)
+
+ dataset_cfg = cfg.dataloader.test
+ wrapper_cfg = cfg.dataloader.wrapper
+
+ aug = instantiate(dataset_cfg.mapper).augmentations
+
+ model = instantiate_odise(cfg.model)
+ model.to(torch.float16)
+ model.to(cfg.train.device)
+ ODISECheckpointer(model).load(cfg.train.init_checkpoint)
+
+
+ title = "ODISE"
+ description = """
+ <p style='text-align: center'> <a href='https://jerryxu.net/ODISE' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2303.04803' target='_blank'>Paper</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Code</a> | <a href='https://youtu.be/Su7p5KYmcII' target='_blank'>Video</a></p>
+
+ Gradio demo for ODISE: Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models. \n
+ You may click one of the examples or upload your own image. \n
+
+ ODISE can perform open-vocabulary segmentation; you may input extra classes (separated by commas).
+ The expected format is 'a1,a2;b1,b2', where a1 and a2 are synonyms for the first class.
+ The first word will be displayed as the class name.
+ """ # noqa
+
+ article = """
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2303.04803' target='_blank'>Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>GitHub Repo</a></p>
+ """ # noqa
+
+ examples = [
+     [
+         "demo/examples/coco.jpg",
+         "black pickup truck, pickup truck; blue sky, sky",
+         ["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+     ],
+     [
+         "demo/examples/ade.jpg",
+         "luggage, suitcase, baggage;handbag",
+         ["ADE (150 categories)"],
+     ],
+     [
+         "demo/examples/ego4d.jpg",
+         "faucet, tap; kitchen paper, paper towels",
+         ["COCO (133 categories)"],
+     ],
+ ]
+
+
+ def build_demo_classes_and_metadata(vocab, label_list):
+     extra_classes = []
+
+     if vocab:
+         for words in vocab.split(";"):
+             extra_classes.append([word.strip() for word in words.split(",")])
+     extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))]
+
+     demo_thing_classes = extra_classes
+     demo_stuff_classes = []
+     demo_thing_colors = extra_colors
+     demo_stuff_colors = []
+
+     if any("COCO" in label for label in label_list):
+         demo_thing_classes += COCO_THING_CLASSES
+         demo_stuff_classes += COCO_STUFF_CLASSES
+         demo_thing_colors += COCO_THING_COLORS
+         demo_stuff_colors += COCO_STUFF_COLORS
+     if any("ADE" in label for label in label_list):
+         demo_thing_classes += ADE_THING_CLASSES
+         demo_stuff_classes += ADE_STUFF_CLASSES
+         demo_thing_colors += ADE_THING_COLORS
+         demo_stuff_colors += ADE_STUFF_COLORS
+     if any("LVIS" in label for label in label_list):
+         demo_thing_classes += LVIS_CLASSES
+         demo_thing_colors += LVIS_COLORS
+
+     MetadataCatalog.pop("odise_demo_metadata", None)
+     demo_metadata = MetadataCatalog.get("odise_demo_metadata")
+     demo_metadata.thing_classes = [c[0] for c in demo_thing_classes]
+     demo_metadata.stuff_classes = [
+         *demo_metadata.thing_classes,
+         *[c[0] for c in demo_stuff_classes],
+     ]
+     demo_metadata.thing_colors = demo_thing_colors
+     demo_metadata.stuff_colors = demo_thing_colors + demo_stuff_colors
+     demo_metadata.stuff_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.stuff_classes))
+     }
+     demo_metadata.thing_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.thing_classes))
+     }
+
+     demo_classes = demo_thing_classes + demo_stuff_classes
+
+     return demo_classes, demo_metadata
+
+
+ def inference(image_path, vocab, label_list):
+
+     logger.info("building class names")
+     demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list)
+     with ExitStack() as stack:
+         inference_model = OpenPanopticInference(
+             model=model,
+             labels=demo_classes,
+             metadata=demo_metadata,
+             semantic_on=False,
+             instance_on=False,
+             panoptic_on=True,
+         )
+         stack.enter_context(inference_context(inference_model))
+         stack.enter_context(torch.no_grad())
+
+         demo = VisualizationDemo(inference_model, demo_metadata, aug)
+         img = utils.read_image(image_path, format="RGB")
+         _, visualized_output = demo.run_on_image(img)
+         return Image.fromarray(visualized_output.get_image())
+
+
+ with gr.Blocks(title=title) as demo:
+     gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+     gr.Markdown(description)
+     input_components = []
+     output_components = []
+
+     with gr.Row():
+         output_image_gr = gr.outputs.Image(label="Panoptic Segmentation", type="pil")
+         output_components.append(output_image_gr)
+
+     with gr.Row().style(equal_height=True, mobile_collapse=True):
+         with gr.Column(scale=3, variant="panel") as input_component_column:
+             input_image_gr = gr.inputs.Image(type="filepath")
+             extra_vocab_gr = gr.inputs.Textbox(default="", label="Extra Vocabulary")
+             category_list_gr = gr.inputs.CheckboxGroup(
+                 choices=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 default=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 label="Category to use",
+             )
+             input_components.extend([input_image_gr, extra_vocab_gr, category_list_gr])
+
+         with gr.Column(scale=2):
+             examples_handler = gr.Examples(
+                 examples=examples,
+                 inputs=[c for c in input_components if not isinstance(c, gr.State)],
+                 outputs=[c for c in output_components if not isinstance(c, gr.State)],
+                 fn=inference,
+                 cache_examples=torch.cuda.is_available(),
+                 examples_per_page=5,
+             )
+             with gr.Row():
+                 clear_btn = gr.Button("Clear")
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+     gr.Markdown(article)
+
+     submit_btn.click(
+         inference,
+         input_components,
+         output_components,
+         api_name="predict",
+         scroll_to_output=True,
+     )
+
+     clear_btn.click(
+         None,
+         [],
+         (input_components + output_components + [input_component_column]),
+         _js=f"""() => {json.dumps(
+             [component.cleared_value if hasattr(component, "cleared_value") else None
+              for component in input_components + output_components] + (
+                 [gr.Column.update(visible=True)]
+             )
+             + ([gr.Column.update(visible=False)])
+         )}
+         """,
+     )
+
+ demo.launch()
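
The "Extra Vocabulary" format documented in the `description` string above ('a1,a2;b1,b2') maps directly to the parsing done in `build_demo_classes_and_metadata`: ';' separates classes, ',' separates synonyms, and the first synonym becomes the displayed class name. A minimal sketch of just that step (illustration only; the standalone helper name `parse_extra_vocab` is not part of the committed file):

```python
# Sketch of the extra-vocabulary parsing used by build_demo_classes_and_metadata:
# ';' separates classes, ',' separates synonyms, first synonym is the display name.
def parse_extra_vocab(vocab: str) -> list:
    extra_classes = []
    if vocab:
        for words in vocab.split(";"):
            extra_classes.append([word.strip() for word in words.split(",")])
    return extra_classes


classes = parse_extra_vocab("black pickup truck, pickup truck; blue sky, sky")
print(classes)                   # [['black pickup truck', 'pickup truck'], ['blue sky', 'sky']]
print([c[0] for c in classes])   # display names: ['black pickup truck', 'blue sky']
```
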
demo/examples/ade.jpg ADDED
demo/examples/coco.jpg ADDED
demo/examples/ego4d.jpg ADDED
packages.txt ADDED
@@ -0,0 +1,4 @@
+ libtinfo5
+ libsm6
+ libxext6
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch==1.13.1+cu116
+ torchvision==0.14.1+cu116
+ xformers==0.0.16
+ numpy==1.23.5
+ matplotlib==3.7.1
+ pillow==9.4.0
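
With the pinned requirements above installed (app.py additionally pip-installs ODISE itself at startup), the demo can also be exercised without the Gradio UI by calling the `inference` function defined in app.py. A minimal sketch, assuming app.py's definitions (`model`, `aug`, `inference`, ...) are already in scope and a checkpoint could be loaded; the arguments mirror one of the bundled examples:

```python
# Sketch only: assumes the globals defined in app.py above are in scope
# (loading the ODISE checkpoint is heavy and benefits from a GPU).
vis = inference(
    image_path="demo/examples/coco.jpg",                            # bundled example image
    vocab="black pickup truck, pickup truck; blue sky, sky",        # extra vocabulary string
    label_list=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
)
vis.save("coco_panoptic.png")  # inference() returns a PIL.Image with the panoptic overlay
```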