Spaces:
Sleeping
Sleeping
Update mask_adapter/sam_maskadapter.py
Browse files — mask_adapter/sam_maskadapter.py (+19 −12)
mask_adapter/sam_maskadapter.py
CHANGED
@@ -227,18 +227,18 @@ class SAMPointVisualizationDemo(object):
|
|
227 |
self.mask_adapter = mask_adapter
|
228 |
|
229 |
|
230 |
-
from .data.datasets import openseg_classes
|
231 |
|
232 |
-
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
233 |
#COCO_CATEGORIES_seg = openseg_classes.get_coco_stuff_categories_with_prompt_eng()
|
234 |
|
235 |
-
thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
236 |
-
stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
237 |
#print(coco_metadata)
|
238 |
-
lvis_classes = open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r').read().splitlines()
|
239 |
-
lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
|
240 |
|
241 |
-
self.class_names = thing_classes + stuff_classes + lvis_classes
|
242 |
#self.text_embedding = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy"))
|
243 |
|
244 |
self.class_names = self._load_class_names()
|
@@ -248,9 +248,11 @@ class SAMPointVisualizationDemo(object):
|
|
248 |
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
249 |
thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
250 |
stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
251 |
-
|
252 |
-
|
253 |
-
|
|
|
|
|
254 |
|
255 |
|
256 |
def extract_features_convnext(self, x):
|
@@ -280,7 +282,9 @@ class SAMPointVisualizationDemo(object):
|
|
280 |
|
281 |
return clip_vis_dense
|
282 |
|
283 |
-
def run_on_image_with_points(self, ori_image, points,text_features):
|
|
|
|
|
284 |
height, width, _ = ori_image.shape
|
285 |
|
286 |
image = ori_image
|
@@ -357,7 +361,10 @@ class SAMPointVisualizationDemo(object):
|
|
357 |
|
358 |
return None, Image.fromarray(overlay)
|
359 |
|
360 |
-
def run_on_image_with_boxes(self, ori_image, bbox,text_features):
|
|
|
|
|
|
|
361 |
height, width, _ = ori_image.shape
|
362 |
|
363 |
image = ori_image
|
|
|
227 |
self.mask_adapter = mask_adapter
|
228 |
|
229 |
|
230 |
+
#from .data.datasets import openseg_classes
|
231 |
|
232 |
+
#COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
233 |
#COCO_CATEGORIES_seg = openseg_classes.get_coco_stuff_categories_with_prompt_eng()
|
234 |
|
235 |
+
#thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
236 |
+
#stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
237 |
#print(coco_metadata)
|
238 |
+
#lvis_classes = open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r').read().splitlines()
|
239 |
+
#lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
|
240 |
|
241 |
+
#self.class_names = thing_classes + stuff_classes + lvis_classes
|
242 |
#self.text_embedding = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy"))
|
243 |
|
244 |
self.class_names = self._load_class_names()
|
|
|
248 |
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
249 |
thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
250 |
stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
251 |
+
ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
|
252 |
+
ade20k_thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES_ if k["isthing"] == 1]
|
253 |
+
ade20k_stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES_]
|
254 |
+
class_names = thing_classes + stuff_classes + ade20k_thing_classes+ ade20k_stuff_classes
|
255 |
+
return [ class_name for class_name in class_names ]
|
256 |
|
257 |
|
258 |
def extract_features_convnext(self, x):
|
|
|
282 |
|
283 |
return clip_vis_dense
|
284 |
|
285 |
+
def run_on_image_with_points(self, ori_image, points,text_features,class_names=None):
|
286 |
+
if class_names != None:
|
287 |
+
self.class_names = class_names
|
288 |
height, width, _ = ori_image.shape
|
289 |
|
290 |
image = ori_image
|
|
|
361 |
|
362 |
return None, Image.fromarray(overlay)
|
363 |
|
364 |
+
def run_on_image_with_boxes(self, ori_image, bbox,text_features,class_names=None):
|
365 |
+
if class_names != None:
|
366 |
+
self.class_names = class_names
|
367 |
+
|
368 |
height, width, _ = ori_image.shape
|
369 |
|
370 |
image = ori_image
|