wondervictor commited on
Commit
fec2205
·
verified ·
1 Parent(s): 55249e4

Update mask_adapter/sam_maskadapter.py

Browse files
Files changed (1) hide show
  1. mask_adapter/sam_maskadapter.py +19 -12
mask_adapter/sam_maskadapter.py CHANGED
@@ -227,18 +227,18 @@ class SAMPointVisualizationDemo(object):
227
  self.mask_adapter = mask_adapter
228
 
229
 
230
- from .data.datasets import openseg_classes
231
 
232
- COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
233
  #COCO_CATEGORIES_seg = openseg_classes.get_coco_stuff_categories_with_prompt_eng()
234
 
235
- thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
236
- stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
237
  #print(coco_metadata)
238
- lvis_classes = open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r').read().splitlines()
239
- lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
240
 
241
- self.class_names = thing_classes + stuff_classes + lvis_classes
242
  #self.text_embedding = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy"))
243
 
244
  self.class_names = self._load_class_names()
@@ -248,9 +248,11 @@ class SAMPointVisualizationDemo(object):
248
  COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
249
  thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
250
  stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
251
- lvis_classes = open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r').read().splitlines()
252
- lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
253
- return thing_classes + stuff_classes + lvis_classes
 
 
254
 
255
 
256
  def extract_features_convnext(self, x):
@@ -280,7 +282,9 @@ class SAMPointVisualizationDemo(object):
280
 
281
  return clip_vis_dense
282
 
283
- def run_on_image_with_points(self, ori_image, points,text_features):
 
 
284
  height, width, _ = ori_image.shape
285
 
286
  image = ori_image
@@ -357,7 +361,10 @@ class SAMPointVisualizationDemo(object):
357
 
358
  return None, Image.fromarray(overlay)
359
 
360
- def run_on_image_with_boxes(self, ori_image, bbox,text_features):
 
 
 
361
  height, width, _ = ori_image.shape
362
 
363
  image = ori_image
 
227
  self.mask_adapter = mask_adapter
228
 
229
 
230
+ #from .data.datasets import openseg_classes
231
 
232
+ #COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
233
  #COCO_CATEGORIES_seg = openseg_classes.get_coco_stuff_categories_with_prompt_eng()
234
 
235
+ #thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
236
+ #stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
237
  #print(coco_metadata)
238
+ #lvis_classes = open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r').read().splitlines()
239
+ #lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
240
 
241
+ #self.class_names = thing_classes + stuff_classes + lvis_classes
242
  #self.text_embedding = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy"))
243
 
244
  self.class_names = self._load_class_names()
 
248
  COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
249
  thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
250
  stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
251
+ ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
252
+ ade20k_thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES_ if k["isthing"] == 1]
253
+ ade20k_stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES_]
254
+ class_names = thing_classes + stuff_classes + ade20k_thing_classes+ ade20k_stuff_classes
255
+ return [ class_name for class_name in class_names ]
256
 
257
 
258
  def extract_features_convnext(self, x):
 
282
 
283
  return clip_vis_dense
284
 
285
+ def run_on_image_with_points(self, ori_image, points,text_features,class_names=None):
286
+ if class_names != None:
287
+ self.class_names = class_names
288
  height, width, _ = ori_image.shape
289
 
290
  image = ori_image
 
361
 
362
  return None, Image.fromarray(overlay)
363
 
364
+ def run_on_image_with_boxes(self, ori_image, bbox,text_features,class_names=None):
365
+ if class_names != None:
366
+ self.class_names = class_names
367
+
368
  height, width, _ = ori_image.shape
369
 
370
  image = ori_image