import torch
import lzma
from pathlib import Path
from dp2.detection.base import BaseDetector
from .mask_rcnn import MaskRCNNDetector
from ..structures import PersonDetection
from tops import logger
from .vit_pose.vit_pose import VitPoseModel
from ..utils import masks_to_boxes

def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
    """Return a bool tensor where entry i is True iff box1[i] lies strictly
    inside at least one box in box2. Boxes are given as (x0, y0, x1, y1)."""
    assert len(box1.shape) == 2
    assert len(box2.shape) == 2
    box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
    # This can be batched
    for i, box in enumerate(box1):
        is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
        is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
        is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
        box1_inside[i] = is_outside.logical_not().any()
    return box1_inside
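
# Concrete illustration of the convention above (values chosen for illustration only):
#   box1_inside_box2(torch.tensor([[2., 2., 3., 3.]]), torch.tensor([[1., 1., 4., 4.]]))
#   -> tensor([True])   # fully contained
#   box1_inside_box2(torch.tensor([[1., 1., 3., 3.]]), torch.tensor([[1., 1., 4., 4.]]))
#   -> tensor([False])  # shares the left/top edge, so not strictly inside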

class MaskRCNNVitPose(BaseDetector):
    """Person detector combining Mask R-CNN instance segmentation with ViT-Pose
    keypoint estimation; the resulting PersonDetection objects carry keypoints
    but no CSE embeddings."""

    def __init__(
            self,
            mask_rcnn_cfg,
            cse_post_process_cfg,
            score_threshold: float,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
        self.vit_pose = VitPoseModel("vit_huge")
        self.cse_post_process_cfg = cse_post_process_cfg

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        logger.log(f"Loading detection from cache path: {cache_path}")
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp, map_location="cpu")
        kwargs = dict(
            post_process_cfg=self.cse_post_process_cfg,
        )
        # Each cached entry stores its detection class under "cls"; rebuild the
        # detection objects from their serialized state dicts.
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    def forward(self, im: torch.Tensor):
        maskrcnn_dets = self.mask_rcnn(im)
        # Keep only the detections classified as person.
        maskrcnn_person = {
            k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
        }
        # Estimate keypoints with ViT-Pose inside boxes derived from the masks,
        # then binarize the keypoint confidence channel at 0.3.
        boxes = masks_to_boxes(maskrcnn_person["segmentation"])
        keypoints = self.vit_pose(im, boxes).cpu()
        keypoints[:, :, -1] = keypoints[:, :, -1] >= 0.3
        persons_without_cse = PersonDetection(
            maskrcnn_person["segmentation"], **self.cse_post_process_cfg,
            orig_imshape_CHW=im.shape,
            keypoints=keypoints
        )
        persons_without_cse.pre_process()
        all_detections = [persons_without_cse]
        return all_detections
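
# Illustrative usage sketch. The config dicts below are placeholders: the real
# values come from the surrounding dp2 configuration system, and BaseDetector
# may require additional keyword arguments (e.g. a cache directory) that are
# omitted here. Only the call pattern (detector(im) -> list of PersonDetection)
# is taken from the code above.
#
#     detector = MaskRCNNVitPose(
#         mask_rcnn_cfg=dict(),          # placeholder Mask R-CNN settings
#         cse_post_process_cfg=dict(),   # placeholder post-processing settings
#         score_threshold=0.3,           # placeholder detection score threshold
#     )
#     im = ...                           # CHW image tensor, as expected by forward()
#     detections = detector(im)          # -> [PersonDetection]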