xvjiarui committed
Commit ef9b1d9 · 0 Parent(s)

init commit

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: ODISE
+ emoji: 🤗
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 3.20.0
+ app_file: app.py
+ pinned: true
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,329 @@
+ # ------------------------------------------------------------------------------
+ # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License.
+ # To view a copy of this license, visit
+ # https://github.com/NVlabs/ODISE/blob/main/LICENSE
+ #
+ # Written by Jiarui Xu
+ # ------------------------------------------------------------------------------
+
+ import os
+ token = os.environ["GITHUB_TOKEN"]
+ os.system(f"pip install git+https://xvjiarui:{token}@github.com/xvjiarui/ODISE_NV.git")
+
+ import itertools
+ import json
+ from contextlib import ExitStack
+ import gradio as gr
+ import torch
+ from mask2former.data.datasets.register_ade20k_panoptic import ADE20K_150_CATEGORIES
+ from PIL import Image
+ from torch.cuda.amp import autocast
+
+ from detectron2.config import instantiate
+ from detectron2.data import MetadataCatalog
+ from detectron2.data import detection_utils as utils
+ from detectron2.data import transforms as T
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+ from detectron2.evaluation import inference_context
+ from detectron2.utils.env import seed_all_rng
+ from detectron2.utils.logger import setup_logger
+ from detectron2.utils.visualizer import ColorMode, Visualizer, random_color
+
+ from odise import model_zoo
+ from odise.checkpoint import ODISECheckpointer
+ from odise.config import instantiate_odise
+ from odise.data import get_openseg_labels
+ from odise.modeling.wrapper import OpenPanopticInference
+ from odise.utils.file_io import ODISEHandler, PathManager
+ from odise.model_zoo.model_zoo import _ModelZooUrls
+
+ for k in ODISEHandler.URLS:
+     ODISEHandler.URLS[k] = ODISEHandler.URLS[k].replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+ PathManager.register_handler(ODISEHandler())
+ _ModelZooUrls.PREFIX = _ModelZooUrls.PREFIX.replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+
+ setup_logger()
+ logger = setup_logger(name="odise")
+
+ COCO_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 1
+ ]
+ COCO_THING_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 1]
+ COCO_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 0
+ ]
+ COCO_STUFF_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 0]
+
+ ADE_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 1
+ ]
+ ADE_THING_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 1]
+ ADE_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 0
+ ]
+ ADE_STUFF_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 0]
+
+ LVIS_CLASSES = get_openseg_labels("lvis_1203", True)
+ # use beautiful coco colors
+ LVIS_COLORS = list(
+     itertools.islice(itertools.cycle([c["color"] for c in COCO_CATEGORIES]), len(LVIS_CLASSES))
+ )
+
+
+ class VisualizationDemo(object):
+     def __init__(self, model, metadata, aug, instance_mode=ColorMode.IMAGE):
+         """
+         Args:
+             model (nn.Module): the segmentation model to run.
+             metadata (MetadataCatalog): image metadata.
+             aug: test-time augmentation applied to the input image
+                 before inference.
+             instance_mode (ColorMode): color mode used by the visualizer.
+         """
+         self.model = model
+         self.metadata = metadata
+         self.aug = aug
+         self.cpu_device = torch.device("cpu")
+         self.instance_mode = instance_mode
+
+     def predict(self, original_image):
+         """
+         Args:
+             original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+         Returns:
+             predictions (dict):
+                 the output of the model for one image only.
+                 See :doc:`/tutorials/models` for details about the format.
+         """
+         height, width = original_image.shape[:2]
+         aug_input = T.AugInput(original_image, sem_seg=None)
+         self.aug(aug_input)
+         image = aug_input.image
+         image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+         inputs = {"image": image, "height": height, "width": width}
+         logger.info("forwarding")
+         with autocast():
+             predictions = self.model([inputs])[0]
+         logger.info("done")
+         return predictions
+
+     def run_on_image(self, image):
+         """
+         Args:
+             image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+                 This is the format used by OpenCV.
+         Returns:
+             predictions (dict): the output of the model.
+             vis_output (VisImage): the visualized image output.
+         """
+         vis_output = None
+         predictions = self.predict(image)
+         visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
+         if "panoptic_seg" in predictions:
+             panoptic_seg, segments_info = predictions["panoptic_seg"]
+             vis_output = visualizer.draw_panoptic_seg(
+                 panoptic_seg.to(self.cpu_device), segments_info
+             )
+         else:
+             if "sem_seg" in predictions:
+                 vis_output = visualizer.draw_sem_seg(
+                     predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
+                 )
+             if "instances" in predictions:
+                 instances = predictions["instances"].to(self.cpu_device)
+                 vis_output = visualizer.draw_instance_predictions(predictions=instances)
+
+         return predictions, vis_output
+
+
+ cfg = model_zoo.get_config("Panoptic/odise_label_coco_50e.py", trained=True)
+
+ cfg.model.overlap_threshold = 0
+ cfg.train.device = "cuda" if torch.cuda.is_available() else "cpu"
+ seed_all_rng(42)
+
+ dataset_cfg = cfg.dataloader.test
+ wrapper_cfg = cfg.dataloader.wrapper
+
+ aug = instantiate(dataset_cfg.mapper).augmentations
+
+ model = instantiate_odise(cfg.model)
+ model.to(torch.float16)
+ model.to(cfg.train.device)
+ ODISECheckpointer(model).load(cfg.train.init_checkpoint)
+
+
+ title = "ODISE"
+ description = """
+ Gradio demo for ODISE: Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models. \n
+ You may click on one of the examples or upload your own image. \n
+
+ ODISE can perform open-vocabulary segmentation: you may input extra classes (separated by commas).
+ The expected format is 'a1,a2;b1,b2', where a1 and a2 are synonyms for the first class, and b1 and b2 for the second.
+ The first word will be displayed as the class name.
+ """ # noqa
+
+ article = """
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2303.04803' target='_blank'>Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Github Repo</a></p>
+ """ # noqa
+
+ examples = [
+     [
+         "demo/examples/coco.jpg",
+         "black pickup truck, pickup truck; blue sky, sky",
+         ["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+     ],
+     [
+         "demo/examples/ade.jpg",
+         "luggage, suitcase, baggage;handbag",
+         ["ADE (150 categories)"],
+     ],
+     [
+         "demo/examples/ego4d.jpg",
+         "faucet, tap; kitchen paper, paper towels",
+         ["COCO (133 categories)"],
+     ],
+ ]
+
+
+ def build_demo_classes_and_metadata(vocab, label_list):
+     extra_classes = []
+
+     if vocab:
+         for words in vocab.split(";"):
+             extra_classes.append([word.strip() for word in words.split(",")])
+     extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))]
+
+     demo_thing_classes = extra_classes
+     demo_stuff_classes = []
+     demo_thing_colors = extra_colors
+     demo_stuff_colors = []
+
+     if any("COCO" in label for label in label_list):
+         demo_thing_classes += COCO_THING_CLASSES
+         demo_stuff_classes += COCO_STUFF_CLASSES
+         demo_thing_colors += COCO_THING_COLORS
+         demo_stuff_colors += COCO_STUFF_COLORS
+     if any("ADE" in label for label in label_list):
+         demo_thing_classes += ADE_THING_CLASSES
+         demo_stuff_classes += ADE_STUFF_CLASSES
+         demo_thing_colors += ADE_THING_COLORS
+         demo_stuff_colors += ADE_STUFF_COLORS
+     if any("LVIS" in label for label in label_list):
+         demo_thing_classes += LVIS_CLASSES
+         demo_thing_colors += LVIS_COLORS
+
+     MetadataCatalog.pop("odise_demo_metadata", None)
+     demo_metadata = MetadataCatalog.get("odise_demo_metadata")
+     demo_metadata.thing_classes = [c[0] for c in demo_thing_classes]
+     demo_metadata.stuff_classes = [
+         *demo_metadata.thing_classes,
+         *[c[0] for c in demo_stuff_classes],
+     ]
+     demo_metadata.thing_colors = demo_thing_colors
+     demo_metadata.stuff_colors = demo_thing_colors + demo_stuff_colors
+     demo_metadata.stuff_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.stuff_classes))
+     }
+     demo_metadata.thing_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.thing_classes))
+     }
+
+     demo_classes = demo_thing_classes + demo_stuff_classes
+
+     return demo_classes, demo_metadata
+
+
+ def inference(image_path, vocab, label_list):
+
+     logger.info("building class names")
+     demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list)
+     with ExitStack() as stack:
+         inference_model = OpenPanopticInference(
+             model=model,
+             labels=demo_classes,
+             metadata=demo_metadata,
+             semantic_on=False,
+             instance_on=False,
+             panoptic_on=True,
+         )
+         stack.enter_context(inference_context(inference_model))
+         stack.enter_context(torch.no_grad())
+
+         demo = VisualizationDemo(inference_model, demo_metadata, aug)
+         img = utils.read_image(image_path, format="RGB")
+         _, visualized_output = demo.run_on_image(img)
+         return Image.fromarray(visualized_output.get_image())
+
+
+ with gr.Blocks(title=title) as demo:
+     gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+     gr.Markdown(description)
+     input_components = []
+     output_components = []
+
+     with gr.Row():
+         output_image_gr = gr.outputs.Image(label="Panoptic Segmentation", type="pil")
+         output_components.append(output_image_gr)
+
+     with gr.Row().style(equal_height=True, mobile_collapse=True):
+         with gr.Column(scale=3, variant="panel") as input_component_column:
+             input_image_gr = gr.inputs.Image(type="filepath")
+             extra_vocab_gr = gr.inputs.Textbox(default="", label="Extra Vocabulary")
+             category_list_gr = gr.inputs.CheckboxGroup(
+                 choices=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 default=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 label="Category to use",
+             )
+             input_components.extend([input_image_gr, extra_vocab_gr, category_list_gr])
+
+         with gr.Column(scale=2):
+             examples_handler = gr.Examples(
+                 examples=examples,
+                 inputs=[c for c in input_components if not isinstance(c, gr.State)],
+                 outputs=[c for c in output_components if not isinstance(c, gr.State)],
+                 fn=inference,
+                 cache_examples=torch.cuda.is_available(),
+                 examples_per_page=5,
+             )
+             with gr.Row():
+                 clear_btn = gr.Button("Clear")
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+     gr.Markdown(article)
+
+     submit_btn.click(
+         inference,
+         input_components,
+         output_components,
+         api_name="predict",
+         scroll_to_output=True,
+     )
+
+     clear_btn.click(
+         None,
+         [],
+         (input_components + output_components + [input_component_column]),
+         _js=f"""() => {json.dumps(
+             [component.cleared_value if hasattr(component, "cleared_value") else None
+              for component in input_components + output_components] + (
+                 [gr.Column.update(visible=True)]
+             )
+             + ([gr.Column.update(visible=False)])
+         )}
+         """,
+     )
+
+ demo.launch()
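
For reference, `inference` above takes an image path, the extra-vocabulary string in the 'a1,a2;b1,b2' format described in `description`, and the list of category sets, and returns a PIL image with the panoptic visualization. A minimal local sketch (not part of this commit; it assumes the definitions from app.py are loaded and `demo.launch()` has not yet been called):

```python
# Hypothetical smoke test, not included in the commit above.
# Reuses inference() and one of the bundled example images.
extra_vocab = "black pickup truck, pickup truck; blue sky, sky"  # two extra classes, comma-separated synonyms
label_sets = ["COCO (133 categories)"]                           # restrict the base vocabulary to COCO

panoptic_vis = inference("demo/examples/coco.jpg", extra_vocab, label_sets)
panoptic_vis.save("coco_panoptic.png")  # inference() returns a PIL.Image
```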
demo/examples/ade.jpg ADDED
demo/examples/coco.jpg ADDED
demo/examples/ego4d.jpg ADDED
packages.txt ADDED
@@ -0,0 +1,4 @@
+ libtinfo5
+ libsm6
+ libxext6
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch==1.13.1+cu116
+ torchvision==0.14.1+cu116
+ xformers==0.0.16
+ numpy<=1.21.5
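
A quick environment check (a sketch, not part of the commit) can confirm that the CUDA 11.6 wheels pinned in requirements.txt are the ones actually imported at runtime:

```python
# Hypothetical verification script: prints the installed versions to compare
# against the pins in requirements.txt.
import numpy
import torch
import torchvision
import xformers

print(torch.__version__)         # expected: 1.13.1+cu116
print(torchvision.__version__)   # expected: 0.14.1+cu116
print(xformers.__version__)      # expected: 0.0.16
print(numpy.__version__)         # expected: <= 1.21.5
print(torch.cuda.is_available()) # True on the GPU Space, False on CPU-only hardware
```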