File size: 6,542 Bytes
97a6728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import List, Optional

import torch
from torchvision.transforms.functional import InterpolationMode, resize

import tops
from densepose import add_densepose_config
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose.modeling import build_densepose_embedder
from densepose.structures import DensePoseEmbeddingPredictorOutput
from densepose.vis.extractor import DensePoseOutputsExtractor
from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data.transforms import ResizeShortestEdge
from detectron2.modeling import build_model


# Maps a DensePose CSE config URL to the URL of its matching pretrained
# checkpoint in the official detectron2 model zoo.
model_urls = {
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
}


def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
    """Paste per-detection segmentation maps into full-image boolean masks.

    Args:
        boxes_XYXY: (N, 4) detection boxes as (x0, y0, x1, y1) in image
            coordinates; must lie inside ``imshape``.
        S: (N, h, w) per-detection segmentation maps, where values > 0 are
            treated as foreground.
        imshape: (H, W) of the target image.

    Returns:
        (N, H, W) boolean tensor; each detection's map is resized with
        nearest-neighbor interpolation to its box size and placed at the
        box location. All pixels outside the box remain False.
    """
    assert len(S.shape) == 3
    H, W = imshape
    N = len(boxes_XYXY)
    segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
    boxes_XYXY = boxes_XYXY.long()
    for i in range(N):
        # Unpack to plain Python ints: torchvision's resize expects integer
        # sizes, not 0-dim tensors, and int slicing is unambiguous.
        x0, y0, x1, y1 = boxes_XYXY[i].tolist()
        assert x0 >= 0 and y0 >= 0
        assert x1 <= imshape[1]
        assert y1 <= imshape[0]
        h = y1 - y0
        w = x1 - x0
        segmentation[i:i+1, y0:y1, x0:x1] = resize(
            S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
    return segmentation


class CSEDetector:
    """DensePose Continuous Surface Embedding (CSE) person detector.

    Wraps a detectron2 DensePose CSE model: downloads the config and
    pretrained weights, builds the model once, and exposes a ``forward``
    that maps an image tensor to per-instance segmentations and CSE
    embeddings.
    """

    def __init__(
            self,
            cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
            cfg_2_download: Optional[List[str]] = None,
            score_thres: float = 0.9,
            nms_thresh: Optional[float] = None,
    ) -> None:
        """Build the detector.

        Args:
            cfg_url: URL of the DensePose CSE config; must be a key of
                ``model_urls`` so the matching checkpoint can be found.
            cfg_2_download: Config files (the main config and its base
                configs) to download before merging. ``None`` selects the
                default set for the R-101 soft CSE model.
            score_thres: ROI-head score threshold at test time.
            nms_thresh: Optional ROI-head NMS threshold override; when
                ``None`` the value from the config file is kept.
        """
        # Resolved here instead of in the signature to avoid a mutable
        # default argument.
        if cfg_2_download is None:
            cfg_2_download = [
                "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
                "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
                "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"]
        with tops.logger.capture_log_stdout():
            cfg = get_cfg()
            self.device = tops.get_device()
            add_densepose_config(cfg)
        cfg_path = tops.download_file(cfg_url)
        # The main config references the base configs via _BASE_; they must
        # exist locally before merge_from_file is called.
        for p in cfg_2_download:
            tops.download_file(p)
        with tops.logger.capture_log_stdout():
            cfg.merge_from_file(cfg_path)
        assert cfg_url in model_urls, cfg_url
        model_path = tops.download_file(model_urls[cfg_url])
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
        if nms_thresh is not None:
            cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
        cfg.MODEL.WEIGHTS = str(model_path)
        cfg.MODEL.DEVICE = str(self.device)
        cfg.freeze()  # all config mutation must happen before this point
        with tops.logger.capture_log_stdout():
            self.model = build_model(cfg)
            self.model.eval()
            DetectionCheckpointer(self.model).load(str(model_path))
            self.input_format = cfg.INPUT.FORMAT
            self.densepose_extractor = DensePoseOutputsExtractor()
            self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)

            self.embedder = build_densepose_embedder(cfg)
            # Per-mesh vertex embedding matrices; only "smpl_27554" is
            # expected by forward() (asserted there).
            self.mesh_vertex_embeddings = {
                mesh_name: self.embedder(mesh_name).to(self.device)
                for mesh_name in self.class_to_mesh_name.values()
                if self.embedder.has_embeddings(mesh_name)
            }
            self.cfg = cfg
            self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
        tops.logger.log("CSEDetector built.")

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def resize_im(self, im):
        """Resize a (C, H, W) image to the config's shortest-edge test size."""
        H, W = im.shape[1:]
        newH, newW = ResizeShortestEdge.get_output_shape(
            H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
        return resize(
            im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)

    @torch.no_grad()
    def forward(self, im):
        """Detect people and their CSE embeddings in a single image.

        Args:
            im: (C, H, W) uint8 image tensor in RGB order (flipped to BGR
                internally when the model config requires it).

        Returns:
            dict with:
                instance_segmentation: (N, h, w) bool per-instance masks in
                    the DensePose head's output resolution.
                instance_embedding: (N, E, h, w) float CSE embeddings.
                embed_map: vertex embedding matrix for the "smpl_27554" mesh.
                bbox_XYXY: (N, 4) long boxes in original image coordinates.
                im_segmentation: (N, H, W) bool masks pasted into the image.
                scores: (N,) detection scores.
            Degenerate (zero-area) boxes are filtered out of all entries.
        """
        assert im.dtype == torch.uint8
        if self.input_format == "BGR":
            im = im.flip(0)
        H, W = im.shape[1:]
        im = self.resize_im(im)
        output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
        scores = output.get("scores")
        if len(scores) == 0:
            # No detections: return empty tensors with the same keys/dtypes
            # as the populated case. 112x112 is the DensePose head output
            # resolution for this config — TODO confirm against cfg.
            return dict(
                instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device),
                instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
                embed_map=self.embed_map,
                bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
                im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
                scores=torch.empty((0,), dtype=torch.float, device=im.device)
            )
        pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
        assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
        # coarse_segm is Nx2xHxW (background/foreground); argmax over the
        # class dim yields a NxHxW label map.
        S = pred_densepose.coarse_segm.argmax(dim=1)
        E = pred_densepose.embedding
        mesh_name = self.class_to_mesh_name[classes[0]]
        assert mesh_name == "smpl_27554"
        x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
        boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
        boxes_XYXY = boxes_XYXY.round_().long()

        # Drop boxes with zero width or height; cse_det_to_global cannot
        # paste an empty box.
        non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
        S = S[non_empty_boxes]
        E = E[non_empty_boxes]
        boxes_XYXY = boxes_XYXY[non_empty_boxes]
        scores = scores[non_empty_boxes]
        im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
        # embed_map is included for consistency with the empty-detection path.
        return dict(
            instance_segmentation=S, instance_embedding=E,
            embed_map=self.embed_map,
            bbox_XYXY=boxes_XYXY,
            im_segmentation=im_segmentation,
            scores=scores.view(-1))