deep_privacy2 / dp2 /detection /models /keypoint_maskrcnn.py
haakohu's picture
fix
44539fc
raw
history blame
4.58 kB
import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
from detectron2.data.transforms import ResizeShortestEdge
from detectron2.structures import Instances
from detectron2 import model_zoo
from detectron2.config import instantiate
from detectron2.config import LazyCall as L
from PIL import Image
import tops
import functools
from torchvision.transforms.functional import resize
def get_rn50_fpn_keypoint_rcnn(weight_path: str):
from detectron2.modeling.poolers import ROIPooler
from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
from detectron2.layers import ShapeSpec
model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
model.roi_heads.update(
num_classes=1,
keypoint_in_features=["p2", "p3", "p4", "p5"],
keypoint_pooler=L(ROIPooler)(
output_size=14,
scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
sampling_ratio=0,
pooler_type="ROIAlignV2",
),
keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
input_shape=ShapeSpec(channels=256, width=14, height=14),
num_keypoints=17,
conv_dims=[512] * 8,
loss_normalizer="visible",
),
)
# Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
# 1000 proposals per-image is found to hurt box AP.
# Therefore we increase it to 1500 per-image.
model.proposal_generator.post_nms_topk = (1500, 1000)
# Keypoint AP degrades (though box AP improves) when using plain L1 loss
model.roi_heads.box_predictor.smooth_l1_beta = 0.5
model = instantiate(model)
dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
test_transform = instantiate(dataloader.test.mapper.augmentations)
DetectionCheckpointer(model).load(weight_path)
return model, test_transform
models = {
"rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth")
}
class KeypointMaskRCNN:
def __init__(self, model_name: str, score_threshold: float) -> None:
assert model_name in models, f"Did not find {model_name} in models"
model, test_transform = models[model_name]()
self.model = model.eval().to(tops.get_device())
if isinstance(self.model.roi_heads, CascadeROIHeads):
for head in self.model.roi_heads.box_predictors:
assert hasattr(head, "test_score_thresh")
head.test_score_thresh = score_threshold
else:
assert isinstance(self.model.roi_heads, StandardROIHeads)
assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
self.model.roi_heads.box_predictor.test_score_thresh = score_threshold
self.test_transform = test_transform
assert len(self.test_transform) == 1
self.test_transform = self.test_transform[0]
assert isinstance(self.test_transform, ResizeShortestEdge)
assert self.test_transform.interp == Image.BILINEAR
self.image_format = self.model.input_format
def resize_im(self, im):
H, W = im.shape[-2:]
if self.test_transform.is_range:
size = np.random.randint(
self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1)
else:
size = np.random.choice(self.test_transform.short_edge_length)
newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
return resize(
im, (newH, newW), antialias=True)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
@torch.no_grad()
def forward(self, im: torch.Tensor):
assert im.ndim == 3
if self.image_format == "BGR":
im = im.flip(0)
H, W = im.shape[-2:]
im = im.float()
im = self.resize_im(im)
inputs = dict(image=im, height=H, width=W)
# instances contains
# dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps'])
instances = self.model([inputs])[0]["instances"]
return dict(
scores=instances.get("scores").cpu(),
segmentation=instances.get("pred_masks").cpu(),
keypoints=instances.get("pred_keypoints").cpu()
)