Spaces:
Runtime error
Runtime error
File size: 6,542 Bytes
97a6728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
from typing import List
import tops
from torchvision.transforms.functional import InterpolationMode, resize
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose import add_densepose_config
from densepose.structures import DensePoseEmbeddingPredictorOutput
from densepose.vis.extractor import DensePoseOutputsExtractor
from densepose.modeling import build_densepose_embedder
from detectron2.config import get_cfg
from detectron2.data.transforms import ResizeShortestEdge
from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model
# Maps a DensePose CSE config URL to the URL of the matching model weights
# file. CSEDetector.__init__ asserts that its cfg_url is a key of this mapping
# and downloads the associated value as the checkpoint to load.
model_urls = {
"https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
"https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
}
def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
    """Paste per-detection masks into full-image boolean masks.

    For every detection ``i`` the mask ``S[i]`` is resized with
    nearest-neighbour interpolation to the size of ``boxes_XYXY[i]`` and
    written into an all-``False`` canvas of size ``imshape`` at the box
    location. Resized values ``> 0`` are treated as foreground.

    Args:
        boxes_XYXY: Tensor of ``N`` boxes given as ``(x0, y0, x1, y1)``.
            Coordinates are cast to integers with ``.long()``.
        S: Tensor with exactly three dimensions, indexed by detection along
            the first one.
        imshape: ``(H, W)`` of the target image.

    Returns:
        Bool tensor of shape ``(N, H, W)`` allocated on ``S.device``.

    Raises:
        AssertionError: If ``S`` is not 3-dimensional or a box reaches
            outside the image.
    """
    assert len(S.shape) == 3
    H, W = imshape
    N = len(boxes_XYXY)
    segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
    # Convert all coordinates to Python ints in one go instead of comparing
    # and slicing with 0-dim tensors for every box.
    boxes = boxes_XYXY.long().tolist()
    for i, (x0, y0, x1, y1) in enumerate(boxes):
        assert x0 >= 0 and y0 >= 0
        assert x1 <= W
        assert y1 <= H
        h = y1 - y0
        w = x1 - x0
        if h <= 0 or w <= 0:
            # A degenerate box covers no pixels; resizing to a non-positive
            # size would raise, so leave this detection's mask empty.
            continue
        segmentation[i:i+1, y0:y1, x0:x1] = resize(
            S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
    return segmentation
class CSEDetector:
    """Wrapper around a detectron2 DensePose CSE model.

    The constructor downloads the config and weights, builds the model in
    eval mode and loads the mesh vertex embeddings. Calling the instance on
    an image runs detection and returns per-instance segmentations,
    embeddings, boxes and scores.
    """

    def __init__(
        self,
        cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
        cfg_2_download: List[str] = [
            "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
            "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
            "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"],
        score_thres: float = 0.9,
        nms_thresh: float = None,
    ) -> None:
        """Build the detector.

        Args:
            cfg_url: URL of the model config; must be a key of ``model_urls``.
            cfg_2_download: Additional config URLs to download before
                ``cfg_url`` is merged. The list is only iterated, never
                mutated.
            score_thres: Value written to
                ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
            nms_thresh: If not ``None``, value written to
                ``cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST``.

        Raises:
            AssertionError: If ``cfg_url`` has no entry in ``model_urls``.
        """
        # NOTE(review): each capture_log_stdout() block is assumed to wrap
        # only the single statement that follows it -- confirm against the
        # intended scope.
        with tops.logger.capture_log_stdout():
            cfg = get_cfg()
        self.device = tops.get_device()
        add_densepose_config(cfg)
        cfg_path = tops.download_file(cfg_url)
        for p in cfg_2_download:
            tops.download_file(p)
        with tops.logger.capture_log_stdout():
            cfg.merge_from_file(cfg_path)
        assert cfg_url in model_urls, cfg_url
        model_path = tops.download_file(model_urls[cfg_url])
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
        if nms_thresh is not None:
            cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
        cfg.MODEL.WEIGHTS = str(model_path)
        cfg.MODEL.DEVICE = str(self.device)
        cfg.freeze()
        with tops.logger.capture_log_stdout():
            self.model = build_model(cfg)
        self.model.eval()
        DetectionCheckpointer(self.model).load(str(model_path))
        self.input_format = cfg.INPUT.FORMAT
        self.densepose_extractor = DensePoseOutputsExtractor()
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.embedder = build_densepose_embedder(cfg)
        # Vertex embeddings for every mesh the embedder knows, moved to the
        # model device.
        self.mesh_vertex_embeddings = {
            mesh_name: self.embedder(mesh_name).to(self.device)
            for mesh_name in self.class_to_mesh_name.values()
            if self.embedder.has_embeddings(mesh_name)
        }
        self.cfg = cfg
        self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
        tops.logger.log("CSEDetector built.")

    def __call__(self, *args, **kwargs):
        """Forward all arguments to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def resize_im(self, im):
        """Resize ``im`` to the test-time input size from the config.

        The target size comes from ``ResizeShortestEdge.get_output_shape``
        with ``cfg.INPUT.MIN_SIZE_TEST`` / ``cfg.INPUT.MAX_SIZE_TEST``; the
        image is resized bilinearly with antialiasing.
        """
        H, W = im.shape[1:]
        newH, newW = ResizeShortestEdge.get_output_shape(
            H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
        return resize(
            im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)

    @torch.no_grad()
    def forward(self, im):
        """Run the detector on one image.

        Args:
            im: ``torch.uint8`` tensor with three dimensions; the last two
                are taken as height and width (presumably ``(C, H, W)`` --
                the first dimension is flipped when the config input format
                is ``"BGR"``).

        Returns:
            dict with ``instance_segmentation``, ``instance_embedding``,
            ``bbox_XYXY``, ``im_segmentation`` (``(N, H, W)`` bool, in the
            coordinates of the input image) and ``scores``. When nothing is
            detected all tensors have a leading dimension of 0 and the dict
            additionally contains ``embed_map``.

        Raises:
            AssertionError: If ``im`` is not uint8, the extractor output is
                not a ``DensePoseEmbeddingPredictorOutput``, or the mesh of
                the first detection is not ``"smpl_27554"``.
        """
        assert im.dtype == torch.uint8
        if self.input_format == "BGR":
            im = im.flip(0)
        H, W = im.shape[1:]
        im = self.resize_im(im)
        # Original H/W are passed so the model reports boxes in the
        # coordinate frame of the un-resized image.
        output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
        scores = output.get("scores")
        if len(scores) == 0:
            # NOTE(review): "embed_map" is only part of this empty result,
            # not of the populated one below; kept as-is for callers.
            return dict(
                # 3-D like the populated branch (argmax over the channel
                # dimension), which is also what cse_det_to_global requires.
                instance_segmentation=torch.empty((0, 112, 112), dtype=torch.bool, device=im.device),
                instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
                embed_map=self.mesh_vertex_embeddings["smpl_27554"],
                bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
                im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
                scores=torch.empty((0,), dtype=torch.float, device=im.device)
            )
        pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
        assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
        S = pred_densepose.coarse_segm.argmax(dim=1)  # Segmentation channel Nx2xHxW (2 because only 2 classes)
        E = pred_densepose.embedding
        mesh_name = self.class_to_mesh_name[classes[0]]
        assert mesh_name == "smpl_27554"
        # Convert (x, y, w, h) boxes to integer (x0, y0, x1, y1).
        x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
        boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
        boxes_XYXY = boxes_XYXY.round_().long()
        # Drop detections whose rounded box has zero width or zero height.
        non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
        S = S[non_empty_boxes]
        E = E[non_empty_boxes]
        boxes_XYXY = boxes_XYXY[non_empty_boxes]
        scores = scores[non_empty_boxes]
        im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
        return dict(
            instance_segmentation=S, instance_embedding=E,
            bbox_XYXY=boxes_XYXY,
            im_segmentation=im_segmentation,
            scores=scores.view(-1))
|