Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
from detectron2.checkpoint import DetectionCheckpointer | |
from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads | |
from detectron2.data.transforms import ResizeShortestEdge | |
from detectron2.structures import Instances | |
from detectron2 import model_zoo | |
from detectron2.config import instantiate | |
from detectron2.config import LazyCall as L | |
from PIL import Image | |
import tops | |
import functools | |
from torchvision.transforms.functional import resize | |
def get_rn50_fpn_keypoint_rcnn(weight_path: str): | |
from detectron2.modeling.poolers import ROIPooler | |
from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead | |
from detectron2.layers import ShapeSpec | |
model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model | |
model.roi_heads.update( | |
num_classes=1, | |
keypoint_in_features=["p2", "p3", "p4", "p5"], | |
keypoint_pooler=L(ROIPooler)( | |
output_size=14, | |
scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), | |
sampling_ratio=0, | |
pooler_type="ROIAlignV2", | |
), | |
keypoint_head=L(KRCNNConvDeconvUpsampleHead)( | |
input_shape=ShapeSpec(channels=256, width=14, height=14), | |
num_keypoints=17, | |
conv_dims=[512] * 8, | |
loss_normalizer="visible", | |
), | |
) | |
# Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. | |
# 1000 proposals per-image is found to hurt box AP. | |
# Therefore we increase it to 1500 per-image. | |
model.proposal_generator.post_nms_topk = (1500, 1000) | |
# Keypoint AP degrades (though box AP improves) when using plain L1 loss | |
model.roi_heads.box_predictor.smooth_l1_beta = 0.5 | |
model = instantiate(model) | |
dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader | |
test_transform = instantiate(dataloader.test.mapper.augmentations) | |
DetectionCheckpointer(model).load(weight_path) | |
return model, test_transform | |
models = { | |
"rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth") | |
} | |
class KeypointMaskRCNN: | |
def __init__(self, model_name: str, score_threshold: float) -> None: | |
assert model_name in models, f"Did not find {model_name} in models" | |
model, test_transform = models[model_name]() | |
self.model = model.eval().to(tops.get_device()) | |
if isinstance(self.model.roi_heads, CascadeROIHeads): | |
for head in self.model.roi_heads.box_predictors: | |
assert hasattr(head, "test_score_thresh") | |
head.test_score_thresh = score_threshold | |
else: | |
assert isinstance(self.model.roi_heads, StandardROIHeads) | |
assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh") | |
self.model.roi_heads.box_predictor.test_score_thresh = score_threshold | |
self.test_transform = test_transform | |
assert len(self.test_transform) == 1 | |
self.test_transform = self.test_transform[0] | |
assert isinstance(self.test_transform, ResizeShortestEdge) | |
assert self.test_transform.interp == Image.BILINEAR | |
self.image_format = self.model.input_format | |
def resize_im(self, im): | |
H, W = im.shape[-2:] | |
if self.test_transform.is_range: | |
size = np.random.randint( | |
self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1) | |
else: | |
size = np.random.choice(self.test_transform.short_edge_length) | |
newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size) | |
return resize( | |
im, (newH, newW), antialias=True) | |
def __call__(self, *args, **kwargs): | |
return self.forward(*args, **kwargs) | |
def forward(self, im: torch.Tensor): | |
assert im.ndim == 3 | |
if self.image_format == "BGR": | |
im = im.flip(0) | |
H, W = im.shape[-2:] | |
im = im.float() | |
im = self.resize_im(im) | |
inputs = dict(image=im, height=H, width=W) | |
# instances contains | |
# dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps']) | |
instances = self.model([inputs])[0]["instances"] | |
return dict( | |
scores=instances.get("scores").cpu(), | |
segmentation=instances.get("pred_masks").cpu(), | |
keypoints=instances.get("pred_keypoints").cpu() | |
) | |