Update mask_adapter/sam_maskadapter.py
mask_adapter/sam_maskadapter.py CHANGED
@@ -343,9 +343,6 @@ class SAMPointVisualizationDemo(object):
         alpha = 0.5
         cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
 
-        # Draw boundary (contours) on the overlay
-        contours, _ = cv2.findContours(pred_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)  # White boundary
 
         # Add label based on the class with the highest score
         max_scores, max_score_idx = class_preds.max(dim=1)  # Find the max score across the class predictions
@@ -358,4 +355,78 @@ class SAMPointVisualizationDemo(object):
         # Put text near the point
         cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
 
         return None, Image.fromarray(overlay)
+
+    def run_on_image_with_boxes(self, ori_image, bbox, text_features):
+        height, width, _ = ori_image.shape
+
+        image = ori_image
+        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        # ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
+
+        with torch.no_grad():
+            self.predictor.set_image(image)
+            masks, _, _ = self.predictor.predict(box=bbox[None, :], multimask_output=False)
+
+        pred_masks = BitMasks(masks)
+
+        image = torch.as_tensor(image.astype("float16").transpose(2, 0, 1))
+
+        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
+        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
+
+        image = (image - pixel_mean) / pixel_std
+        image = image.unsqueeze(0)
+
+        # txts = [f'a photo of {cls_name}' for cls_name in self.class_names]
+        # text = open_clip.tokenize(txts)
+
+        with torch.no_grad():
+            # text_features = self.clip_model.encode_text(text.cuda())
+            # text_features /= text_features.norm(dim=-1, keepdim=True)
+            # np.save("/home/yongkangli/Mask-Adapter/text_embedding/lvis_coco_text_embedding.npy", text_features.cpu().numpy())
+            # text_features = self.text_embedding.to(self.mask_adapter.device)
+            features = self.extract_features_convnext(image.to(text_features).float())
+            clip_feature = features['clip_vis_dense']
+
+            clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)
+
+            semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).to(text_features).float())
+            maps_for_pooling = F.interpolate(semantic_activation_maps, size=clip_feature.shape[-2:], mode='bilinear', align_corners=False)
+
+            B, C = clip_feature.size(0), clip_feature.size(1)
+            N = maps_for_pooling.size(1)
+            num_instances = N // 16
+            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
+            pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
+            pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
+            pooled_clip_feature = (pooled_clip_feature.reshape(B, num_instances, 16, -1).mean(dim=-2).contiguous())
+
+            class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
+            class_preds = class_preds.squeeze(0)
+
+        # Resize mask to match original image size
+        pred_mask = cv2.resize(masks.squeeze(0), (width, height), interpolation=cv2.INTER_NEAREST)
+
+        # Create an overlay for the mask with a transparent background (alpha blending)
+        overlay = ori_image.copy()
+        mask_colored = np.zeros_like(ori_image)
+        mask_colored[pred_mask == 1] = [234, 103, 112]  # Mask color (BGR)
+
+        alpha = 0.5
+        cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
+
+        # Add label based on the class with the highest score
+        max_scores, max_score_idx = class_preds.max(dim=1)  # Find the max score across the class predictions
+        label = f"{self.class_names[max_score_idx.item()]}: {max_scores.item():.2f}"
+
+        # Dynamically place the label near the box
+        text_x = min(width - 200, bbox[0] + 20)  # Add some offset from the box corner
+        text_y = min(height - 30, bbox[1] + 20)  # Ensure the text does not go out of bounds
+
+        # Put text near the label position
+        cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+
+        return None, Image.fromarray(overlay)
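
For context, a minimal sketch of how the new run_on_image_with_boxes method might be driven, assuming an already constructed SAMPointVisualizationDemo instance, an OpenCV BGR image, an xyxy box, and precomputed CLIP text embeddings. The constructor arguments, file names, and variable names below are illustrative assumptions, not part of this commit:

import cv2
import numpy as np
import torch

# Assumed setup: the constructor argument and the embedding file path are
# placeholders, not defined by this diff.
demo = SAMPointVisualizationDemo(cfg)          # hypothetical config/weights handle
image = cv2.imread("example.jpg")              # H x W x 3 BGR uint8 image
bbox = np.array([50, 80, 400, 360])            # xyxy box in pixel coordinates
text_features = torch.from_numpy(
    np.load("text_embedding/lvis_coco_text_embedding.npy")
).cuda()                                       # precomputed CLIP text embeddings

# Segment the box with SAM, classify the mask via the mask adapter and
# pooled CLIP features, and receive a PIL image with the blended overlay.
_, overlay = demo.run_on_image_with_boxes(image, bbox, text_features)
overlay.save("box_overlay.png")

Unlike the point-based path, this variant prompts SAM with box=bbox[None, :] and multimask_output=False, so a single mask per box is produced, classified, and drawn with the label of the highest-scoring class.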