wondervictor committed
Commit 45caf8e · verified · 1 Parent(s): 2ed7fa7

Update mask_adapter/sam_maskadapter.py

Files changed (1)
  1. mask_adapter/sam_maskadapter.py +74 -3
mask_adapter/sam_maskadapter.py CHANGED
@@ -343,9 +343,6 @@ class SAMPointVisualizationDemo(object):
         alpha = 0.5
         cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
 
-        # Draw boundary (contours) on the overlay
-        contours, _ = cv2.findContours(pred_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)  # White boundary
 
         # Add label based on the class with the highest score
         max_scores, max_score_idx = class_preds.max(dim=1)  # Find the max score across the class predictions
@@ -358,4 +355,78 @@
         # Put text near the point
         cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
 
+        return None, Image.fromarray(overlay)
+
+    def run_on_image_with_boxes(self, ori_image, bbox, text_features):
+        height, width, _ = ori_image.shape
+
+        image = ori_image
+        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        # ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
+
+
+        with torch.no_grad():
+            self.predictor.set_image(image)
+            masks, _, _ = self.predictor.predict(box=bbox[None, :], multimask_output=False)
+
+        pred_masks = BitMasks(masks)
+
+        image = torch.as_tensor(image.astype("float16").transpose(2, 0, 1))
+
+        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
+        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
+
+        image = (image - pixel_mean) / pixel_std
+        image = image.unsqueeze(0)
+
+        # txts = [f'a photo of {cls_name}' for cls_name in self.class_names]
+        # text = open_clip.tokenize(txts)
+
+        with torch.no_grad():
+            # text_features = self.clip_model.encode_text(text.cuda())
+            # text_features /= text_features.norm(dim=-1, keepdim=True)
+            # np.save("/home/yongkangli/Mask-Adapter/text_embedding/lvis_coco_text_embedding.npy", text_features.cpu().numpy())
+            # text_features = self.text_embedding.to(self.mask_adapter.device)
+            features = self.extract_features_convnext(image.to(text_features).float())
+            clip_feature = features['clip_vis_dense']
+
+            clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)
+
+            semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).to(text_features).float())
+            maps_for_pooling = F.interpolate(semantic_activation_maps, size=clip_feature.shape[-2:], mode='bilinear', align_corners=False)
+
+            B, C = clip_feature.size(0), clip_feature.size(1)
+            N = maps_for_pooling.size(1)
+            num_instances = N // 16
+            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
+            pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
+            pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
+            pooled_clip_feature = pooled_clip_feature.reshape(B, num_instances, 16, -1).mean(dim=-2).contiguous()
+
+            class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
+            class_preds = class_preds.squeeze(0)
+
+        # Resize mask to match original image size
+        pred_mask = cv2.resize(masks.squeeze(0), (width, height), interpolation=cv2.INTER_NEAREST)
+
+        # Create an overlay for the mask with a transparent background (using alpha transparency)
+        overlay = ori_image.copy()
+        mask_colored = np.zeros_like(ori_image)
+        mask_colored[pred_mask == 1] = [234, 103, 112]  # Coral-red color for the mask
+
+        alpha = 0.5
+        cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
+
+
+        # Add label based on the class with the highest score
+        max_scores, max_score_idx = class_preds.max(dim=1)  # Find the max score across the class predictions
+        label = f"{self.class_names[max_score_idx.item()]}: {max_scores.item():.2f}"
+
+        # Dynamically place the label near the top-left corner of the box
+        text_x = min(width - 200, bbox[0] + 20)  # Add some offset from the box corner
+        text_y = min(height - 30, bbox[1] + 20)  # Ensure the text does not go out of bounds
+
+        # Put text near the box
+        cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+
         return None, Image.fromarray(overlay)
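
For reference, a minimal sketch of how the new box-prompted path might be driven. Everything outside the method call is an illustrative assumption, not part of this commit: `demo` stands for a fully constructed SAMPointVisualizationDemo (SAM predictor, mask adapter, CLIP backbone, and class_names already loaded), and the image path and embedding file are placeholder names. The signature, the XYXY box array, and the (None, PIL.Image) return value follow the diff above.

import cv2
import numpy as np
import torch

image = cv2.imread("example.jpg")      # HxWx3 uint8 image (placeholder path)
bbox = np.array([120, 80, 360, 420])   # single XYXY box prompt, passed as bbox[None, :] internally
text_features = torch.from_numpy(
    np.load("text_embeddings.npy")     # assumed precomputed CLIP text embeddings, shape (num_classes, D)
).cuda()

_, overlay = demo.run_on_image_with_boxes(image, bbox, text_features)
overlay.save("overlay.png")            # the method returns a PIL Image overlay

Judging by the commented-out encode_text path, text_features is expected to be L2-normalized and to live on the same device as the mask adapter; the pooling step also assumes the adapter emits 16 semantic activation maps per instance (hence num_instances = N // 16, with groups of 16 averaged after pooling).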