Update mask_adapter/sam_maskadapter.py
mask_adapter/sam_maskadapter.py
CHANGED
@@ -343,9 +343,6 @@ class SAMPointVisualizationDemo(object):
         alpha = 0.5
         cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
 
-        # Draw boundary (contours) on the overlay
-        contours, _ = cv2.findContours(pred_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)  # White boundary
 
         # Add label based on the class with the highest score
         max_scores, max_score_idx = class_preds.max(dim=1)  # Find the max score across the class predictions
@@ -358,4 +355,78 @@ class SAMPointVisualizationDemo(object):
         # Put text near the point
         cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
 
+        return None, Image.fromarray(overlay)
+
+    def run_on_image_with_boxes(self, ori_image, bbox, text_features):
+        height, width, _ = ori_image.shape
+
+        image = ori_image
+        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        # ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
+
+
+        with torch.no_grad():
+            self.predictor.set_image(image)
+            masks, _, _ = self.predictor.predict(box=bbox[None, :], multimask_output=False)
+
+        pred_masks = BitMasks(masks)
+
+        image = torch.as_tensor(image.astype("float16").transpose(2, 0, 1))
+
+        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
+        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
+
+        image = (image - pixel_mean) / pixel_std
+        image = image.unsqueeze(0)
+
+        # txts = [f'a photo of {cls_name}' for cls_name in self.class_names]
+        # text = open_clip.tokenize(txts)
+
+        with torch.no_grad():
+            # text_features = self.clip_model.encode_text(text.cuda())
+            # text_features /= text_features.norm(dim=-1, keepdim=True)
+            # np.save("/home/yongkangli/Mask-Adapter/text_embedding/lvis_coco_text_embedding.npy", text_features.cpu().numpy())
+            # text_features = self.text_embedding.to(self.mask_adapter.device)
+            features = self.extract_features_convnext(image.to(text_features).float())
+            clip_feature = features['clip_vis_dense']
+
+            clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)
+
+            semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).to(text_features).float())
+            maps_for_pooling = F.interpolate(semantic_activation_maps, size=clip_feature.shape[-2:], mode='bilinear', align_corners=False)
+
+            B, C = clip_feature.size(0), clip_feature.size(1)
+            N = maps_for_pooling.size(1)
+            num_instances = N // 16
+            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
+            pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
+            pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
+            pooled_clip_feature = (pooled_clip_feature.reshape(B, num_instances, 16, -1).mean(dim=-2).contiguous())
+
+            class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
+            class_preds = class_preds.squeeze(0)
+
+        # Resize mask to match original image size
+        pred_mask = cv2.resize(masks.squeeze(0), (width, height), interpolation=cv2.INTER_NEAREST)
+
+        # Create an overlay for the mask with a transparent background (alpha blending)
+        overlay = ori_image.copy()
+        mask_colored = np.zeros_like(ori_image)
+        mask_colored[pred_mask == 1] = [234, 103, 112]  # Mask highlight color (RGB)
+
+        alpha = 0.5
+        cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
+
+
+        # Add label based on the class with the highest score
+        max_scores, max_score_idx = class_preds.max(dim=1)  # Find the max score across the class predictions
+        label = f"{self.class_names[max_score_idx.item()]}: {max_scores.item():.2f}"
+
+        # Dynamically place the label near the box's top-left corner
+        text_x = min(width - 200, bbox[0] + 20)  # Add some offset from the corner
+        text_y = min(height - 30, bbox[1] + 20)  # Ensure the text does not go out of bounds
+
+        # Put text near the box
+        cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+
         return None, Image.fromarray(overlay)
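For reference, a minimal usage sketch of the new run_on_image_with_boxes method (not part of the commit). It assumes a SAMPointVisualizationDemo instance named demo has already been built with its SAM predictor, mask adapter, and class_names, and that text_features is a precomputed (num_classes, feature_dim) CLIP text-embedding tensor on the model's device; the image and .npy paths below are placeholders.

import cv2
import numpy as np
import torch

# Load an image; the method uses the array as given (its cvtColor calls are
# commented out), so convert BGR -> RGB here if an RGB overlay is desired.
image = cv2.imread("example.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Box prompt in (x0, y0, x1, y1) pixel coordinates, matching
# predictor.predict(box=bbox[None, :], multimask_output=False) above.
bbox = np.array([50, 80, 400, 360])

# Precomputed text embeddings, one row per entry in demo.class_names;
# the path is a placeholder mirroring the commented-out np.save above.
text_features = torch.from_numpy(
    np.load("text_embedding/lvis_coco_text_embedding.npy")
).float().cuda()

# Returns (None, PIL.Image) per the method's final return statement.
_, overlay = demo.run_on_image_with_boxes(image, bbox, text_features)
overlay.save("overlay_with_label.png")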