File size: 4,600 Bytes
51f6859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import cv2
import torch
import torch.nn as nn
from mmcv import Config
from mmcv.runner import load_checkpoint

from mmdet.core import bbox2result
from mmdet.models import DETECTORS, BaseDetector
from projects.instance_segment_anything.models.segment_anything import sam_model_registry, SamPredictor
from .focalnet_dino.focalnet_dino_wrapper import FocalNetDINOWrapper
from .hdetr.hdetr_wrapper import HDetrWrapper


@DETECTORS.register_module()
class DetWrapperInstanceSAM(BaseDetector):
    """Box detector + Segment Anything instance segmentation wrapper.

    A detector wrapper (H-DETR or FocalNet-DINO, chosen by
    ``det_wrapper_type``) predicts boxes, labels and scores. Every predicted
    box is then given to a ``SamPredictor`` as a box prompt to obtain one
    mask per box. Only single-image, non-augmented inference
    (``simple_test``) is implemented; training and the other test modes
    raise ``NotImplementedError``.
    """

    # Maps ``det_wrapper_type`` to the detector wrapper class that is built.
    wrapper_dict = {'hdetr': HDetrWrapper,
                    'focalnet_dino': FocalNetDINOWrapper}

    def __init__(self,
                 det_wrapper_type='hdetr',
                 det_wrapper_cfg=None,
                 det_model_ckpt=None,
                 num_classes=80,

                 model_type='vit_b',
                 sam_checkpoint=None,
                 use_sam_iou=True,

                 init_cfg=None,
                 train_cfg=None,
                 test_cfg=None):
        """Build the detector wrapper and the SAM predictor.

        Args:
            det_wrapper_type (str): Key into ``wrapper_dict``
                (``'hdetr'`` or ``'focalnet_dino'``).
            det_wrapper_cfg (dict | None): Converted to an mmcv ``Config``
                and passed to the wrapper as ``args``.
            det_model_ckpt (str | None): Optional checkpoint loaded (with
                ``map_location='cpu'``) into ``self.det_model.model``.
            num_classes (int): Number of classes used to group the results.
            model_type (str): Key into ``sam_model_registry``.
            sam_checkpoint (str | None): Passed to the SAM model builder.
            use_sam_iou (bool): If True, the final score is the detector
                score times SAM's score output; otherwise it is the detector
                score times the mean mask logit inside the predicted mask.
            init_cfg (dict | None): Forwarded to ``BaseDetector``.
            train_cfg, test_cfg: Accepted for config compatibility; unused.
        """
        super().__init__(init_cfg)
        # NOTE(review): dummy parameter; presumably kept so this module always
        # owns a registered parameter. Its device is used below as the
        # initial device for SAM.
        self.learnable_placeholder = nn.Embedding(1, 1)
        det_wrapper_cfg = Config(det_wrapper_cfg)
        assert det_wrapper_type in self.wrapper_dict, (
            f'Unknown det_wrapper_type {det_wrapper_type!r}; '
            f'expected one of {sorted(self.wrapper_dict)}')
        self.det_model = self.wrapper_dict[det_wrapper_type](args=det_wrapper_cfg)
        if det_model_ckpt is not None:
            load_checkpoint(self.det_model.model,
                            filename=det_model_ckpt,
                            map_location='cpu')

        self.num_classes = num_classes

        # Segment Anything
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=self.learnable_placeholder.weight.device)
        # NOTE(review): SAM is only reachable through the predictor, so it is
        # presumably not a registered submodule and does not follow
        # ``.to()`` / ``.cuda()`` calls on this detector; ``simple_test``
        # re-syncs its device before use.
        self.predictor = SamPredictor(sam)
        self.use_sam_iou = use_sam_iou

    def init_weights(self):
        # Intentionally a no-op: pretrained weights are loaded in ``__init__``
        # from ``det_model_ckpt`` / ``sam_checkpoint``.
        pass

    def simple_test(self, img, img_metas, ori_img, rescale=True):
        """Test a single image without augmentation.

        Args:
            img (Tensor): Pre-processed input forwarded to the detector
                wrapper's ``simple_test``.
            img_metas (list[dict]): Image meta information; exactly one
                entry is supported.
            ori_img: Original image handed to ``SamPredictor.set_image``;
                ``ori_img.shape[:2]`` is used as its (H, W). Presumably an
                HWC numpy array in the format SAM expects - confirm with the
                caller.
            rescale (bool): Must be True, so that boxes are at the original
                image scale and line up with ``ori_img``.

        Returns:
            list[tuple]: A single ``(bbox_results, mask_results)`` tuple.
            ``bbox_results`` is the per-class output of ``bbox2result`` with
            rows ``[x1, y1, x2, y2, score]``. ``mask_results`` holds, per
            class, a list of boolean numpy masks appended in detection order.
        """
        assert rescale
        assert len(img_metas) == 1
        # results: List[dict(scores, labels, boxes)]
        results = self.det_model.simple_test(img,
                                             img_metas,
                                             rescale)
        det_result = results[0]

        # Tensor(n,4), xyxy, ori image scale
        output_boxes = det_result['boxes']
        # Tensor(n,)
        label_pred = det_result['labels']
        # Tensor(n,)
        score_pred = det_result['scores']

        # No detections: skip SAM entirely (prompting it with zero boxes is
        # wasted work) and return empty per-class results.
        if output_boxes.shape[0] == 0:
            empty_bboxes = output_boxes.new_zeros((0, 5))
            bbox_results = bbox2result(empty_bboxes, label_pred, self.num_classes)
            mask_results = [[] for _ in range(self.num_classes)]
            return [(bbox_results, mask_results)]

        # SAM may have been left on another device when this detector was
        # moved (see __init__); keep it with the boxes it is prompted with.
        sam_model = self.predictor.model
        if next(sam_model.parameters()).device != output_boxes.device:
            sam_model.to(output_boxes.device)

        self.predictor.set_image(ori_img)

        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            output_boxes, ori_img.shape[:2])

        # mask_pred: n,1,h,w (raw logits, since return_logits=True)
        # sam_score: n,1
        mask_pred, sam_score, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
            return_logits=True,
        )
        # Tensor(n,h,w), raw mask pred
        mask_pred = mask_pred.squeeze(1)
        # Tensor(n,)
        sam_score = sam_score.squeeze(-1)

        mask_pred_binary = (mask_pred > sam_model.mask_threshold).float()
        if self.use_sam_iou:
            det_scores = score_pred * sam_score
        else:
            # Mean logit over the pixels inside each predicted mask; the
            # epsilon avoids division by zero for empty masks.
            mask_scores_per_image = (mask_pred * mask_pred_binary).flatten(1).sum(1) / (
                mask_pred_binary.flatten(1).sum(1) + 1e-6)
            det_scores = score_pred * mask_scores_per_image

        bboxes = torch.cat([output_boxes, det_scores[:, None]], dim=-1)
        bbox_results = bbox2result(bboxes, label_pred, self.num_classes)

        # One host transfer for all masks and labels instead of one per
        # detection; plain Python ints index the per-class lists directly.
        masks_np = mask_pred_binary.bool().detach().cpu().numpy()
        mask_results = [[] for _ in range(self.num_classes)]
        for mask, label in zip(masks_np, label_pred.tolist()):
            mask_results[label].append(mask)

        return [(bbox_results, mask_results)]

    # not implemented:
    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        raise NotImplementedError

    async def async_simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def extract_feat(self, imgs):
        raise NotImplementedError