wondervictor commited on
Commit
bcf38fb
·
verified ·
1 Parent(s): a06ecb5

Update mask_adapter/sam_maskadapter.py

Browse files
Files changed (1) hide show
  1. mask_adapter/sam_maskadapter.py +3 -2
mask_adapter/sam_maskadapter.py CHANGED
@@ -139,7 +139,8 @@ class SAMVisualizationDemo(object):
139
  image = (image - pixel_mean) / pixel_std
140
 
141
  image = image.unsqueeze(0)
142
-
 
143
  # if len(class_names) == 1:
144
  # class_names.append('others')
145
  # txts = [f'a photo of {cls_name}' for cls_name in class_names]
@@ -305,7 +306,7 @@ class SAMPointVisualizationDemo(object):
305
 
306
  # txts = [f'a photo of {cls_name}' for cls_name in self.class_names]
307
  # text = open_clip.tokenize(txts)
308
-
309
  with torch.no_grad():
310
  # text_features = self.clip_model.encode_text(text.cuda())
311
  # text_features /= text_features.norm(dim=-1, keepdim=True)
 
139
  image = (image - pixel_mean) / pixel_std
140
 
141
  image = image.unsqueeze(0)
142
+
143
+ image = image.to(text_features)
144
  # if len(class_names) == 1:
145
  # class_names.append('others')
146
  # txts = [f'a photo of {cls_name}' for cls_name in class_names]
 
306
 
307
  # txts = [f'a photo of {cls_name}' for cls_name in self.class_names]
308
  # text = open_clip.tokenize(txts)
309
+
310
  with torch.no_grad():
311
  # text_features = self.clip_model.encode_text(text.cuda())
312
  # text_features /= text_features.norm(dim=-1, keepdim=True)