# Code adapted from: https://github.com/gpastal24/ViTPose-Pytorch
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
import tops
from .topdown_heatmap_simple_head import TopdownHeatmapSimpleHead
from .backbone import ViT

model_large = dict(
    type="TopDown",
    pretrained=None,
    backbone=dict(
        type="ViT",
        img_size=(256, 192),
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        ratio=1,
        use_checkpoint=False,
        mlp_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.5,
    ),
    keypoint_head=dict(
        type="TopdownHeatmapSimpleHead",
        in_channels=1024,
        num_deconv_layers=2,
        num_deconv_filters=(256, 256),
        num_deconv_kernels=(4, 4),
        extra=dict(final_conv_kernel=1),
        out_channels=17,
        loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True),
    ),
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process="default",
        shift_heatmap=False,
        target_type="GaussianHeatmap",
        modulate_kernel=11,
        use_udp=True,
    ),
)

model_base = dict(
    type="TopDown",
    pretrained=None,
    backbone=dict(
        type="ViT",
        img_size=(256, 192),
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        ratio=1,
        use_checkpoint=False,
        mlp_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.3,
    ),
    keypoint_head=dict(
        type="TopdownHeatmapSimpleHead",
        in_channels=768,
        num_deconv_layers=2,
        num_deconv_filters=(256, 256),
        num_deconv_kernels=(4, 4),
        extra=dict(final_conv_kernel=1),
        out_channels=17,
        loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True),
    ),
    train_cfg=dict(),
    test_cfg=dict(),
)

model_huge = dict(
    type="TopDown",
    pretrained=None,
    backbone=dict(
        type="ViT",
        img_size=(256, 192),
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        ratio=1,
        use_checkpoint=False,
        mlp_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.55,
    ),
    keypoint_head=dict(
        type="TopdownHeatmapSimpleHead",
        in_channels=1280,
        num_deconv_layers=2,
        num_deconv_filters=(256, 256),
        num_deconv_kernels=(4, 4),
        extra=dict(final_conv_kernel=1),
        out_channels=17,
        loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True),
    ),
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process="default",
        shift_heatmap=False,
        target_type="GaussianHeatmap",
        modulate_kernel=11,
        use_udp=True,
    ),
)


class VitPoseModel(nn.Module):

    def __init__(self, model_name):
        super().__init__()
        assert model_name in ["vit_base", "vit_large", "vit_huge"]
        model = {
            "vit_base": model_base,
            "vit_large": model_large,
            "vit_huge": model_huge,
        }[model_name]
        weight_url = {
            "vit_base": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/90235a26-3b8c-427d-a264-c68155abecdcfcfcd8a9-0388-4575-b85b-607d3c0a9b149bef8f0f-a0f9-4662-a561-1b47ba5f1636",
            "vit_large": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/a580a44c-0afd-43ac-a2cb-9956c32b1d1a78c51ecb-81bb-4345-8710-13904cb9dbbe0703db2d-8534-42e0-ac4d-518ab51fe7db",
            "vit_huge": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/a33b6ada-4d2f-4ef7-8f83-b33f58b69f5b2a62e181-2131-467d-a900-027157a08571d761fad4-785b-4b84-8596-8932c7857e44",
        }[model_name]
        file_name = {
            "vit_base": "vit-b-multi-coco-595b5e128b.pth",
            "vit_large": "vit-l-multi-coco-9475d27cec.pth",
            "vit_huge": "vit-h-multi-coco-dbc06d4337.pth",
        }[model_name]
        # Set check_hash to True if you suspect a download error.
        weight_path = tops.download_file(weight_url, file_name=file_name, check_hash=True)
        self.keypoint_head = tops.to_cuda(TopdownHeatmapSimpleHead(
            in_channels=model["keypoint_head"]["in_channels"],
            out_channels=model["keypoint_head"]["out_channels"],
            num_deconv_filters=model["keypoint_head"]["num_deconv_filters"],
            num_deconv_kernels=model["keypoint_head"]["num_deconv_kernels"],
            num_deconv_layers=model["keypoint_head"]["num_deconv_layers"],
            extra=model["keypoint_head"]["extra"],
        ))
        self.backbone = tops.to_cuda(ViT(
            img_size=model["backbone"]["img_size"],
            patch_size=model["backbone"]["patch_size"],
            embed_dim=model["backbone"]["embed_dim"],
            depth=model["backbone"]["depth"],
            num_heads=model["backbone"]["num_heads"],
            ratio=model["backbone"]["ratio"],
            mlp_ratio=model["backbone"]["mlp_ratio"],
            qkv_bias=model["backbone"]["qkv_bias"],
            drop_path_rate=model["backbone"]["drop_path_rate"],
        ))
        ckpt = torch.load(weight_path, map_location=tops.get_device())
        self.load_state_dict(ckpt["state_dict"])
        self.backbone.eval()
        self.keypoint_head.eval()

    def forward(self, img: torch.Tensor, boxes_ltrb: torch.Tensor):
        assert img.ndim == 3
        assert img.dtype == torch.uint8
        assert boxes_ltrb.ndim == 2 and boxes_ltrb.shape[1] == 4
        assert boxes_ltrb.dtype == torch.long
        boxes_ltrb = boxes_ltrb.clamp(0)
        padded_boxes = torch.zeros_like(boxes_ltrb)
        images = torch.zeros((len(boxes_ltrb), 3, 256, 192), device=img.device, dtype=torch.float32)
        for i, (x0, y0, x1, y1) in enumerate(boxes_ltrb):
            x1 = min(img.shape[-1], x1)
            y1 = min(img.shape[-2], y1)
            # Expand each box so its aspect ratio matches the 256x192 model input.
            correction_factor = 256 / 192 * (x1 - x0) / (y1 - y0)
            if correction_factor > 1:
                # Box is too wide relative to the target aspect ratio: increase the y side.
                center = y0 + (y1 - y0) // 2
                length = (y1 - y0).mul(correction_factor).round().long()
                y0_new = center - length.div(2).long()
                y1_new = center + length.div(2).long()
                image_crop = img[:, y0:y1, x0:x1]
                # F.pad order is (left, right, top, bottom): pad only the height here.
                pad = ((y0_new - y0).abs(), (y1_new - y1).abs())
                image_crop = torch.nn.functional.pad(image_crop, [*(0, 0), *pad])
                padded_boxes[i] = torch.tensor([x0, y0_new, x1, y1_new])
            else:
                # Box is too tall: increase the x side.
                center = x0 + (x1 - x0) // 2
                length = (x1 - x0).div(correction_factor).round().long()
                x0_new = center - length.div(2).long()
                x1_new = center + length.div(2).long()
                image_crop = img[:, y0:y1, x0:x1]
                # Pad only the width (left, right).
                pad = ((x0_new - x0).abs(), (x1_new - x1).abs())
                image_crop = torch.nn.functional.pad(image_crop, [*pad])
                padded_boxes[i] = torch.tensor([x0_new, y0, x1_new, y1])
            image_crop = F.resize(image_crop.float(), (256, 192), antialias=True)
            image_crop = F.normalize(
                image_crop,
                mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
            images[i] = image_crop
        x = self.backbone(images)
        out = self.keypoint_head(x)  # Heatmaps of shape (N, 17, 64, 48).
        pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32, device=img.device)
        # For each human, for each joint: y, x, confidence.
        b, indices = torch.max(out, dim=2)
        b, indices = torch.max(b, dim=2)  # indices: column (x) of the heatmap peak per joint.
        c, indicesc = torch.max(out, dim=3)
        c, indicesc = torch.max(c, dim=2)  # indicesc: row (y) of the peak; c: peak confidence.
        dim1 = torch.tensor(1. / 64, device=img.device)
        dim2 = torch.tensor(1. / 48, device=img.device)
        for i in range(0, out.shape[0]):
            # Map heatmap indices back to image coordinates of the padded box.
            pts[i, :, 0] = indicesc[i, :] * dim1 * (padded_boxes[i][3] - padded_boxes[i][1]) + padded_boxes[i][1]
            pts[i, :, 1] = indices[i, :] * dim2 * (padded_boxes[i][2] - padded_boxes[i][0]) + padded_boxes[i][0]
            pts[i, :, 2] = c[i, :]
        pts = pts[:, :, [1, 0, 2]]  # Reorder to (x, y, confidence).
        return pts
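

# Hedged usage sketch (not part of the adapted source): shows how one might call the
# model on a single frame. It assumes a CUDA device, the `tops` package, and network
# access to download the checkpoint on first use; the frame size and box values are
# illustrative only. Because this module uses relative imports, the snippet must be
# run in package context (e.g. via `python -m <package>.<module>`), not as a script.
if __name__ == "__main__":
    model = VitPoseModel("vit_base")
    # Dummy CHW uint8 frame and one person box in (left, top, right, bottom) order.
    frame = torch.randint(0, 256, (3, 720, 1280), dtype=torch.uint8, device=tops.get_device())
    boxes = torch.tensor([[300, 100, 500, 600]], dtype=torch.long, device=frame.device)
    with torch.no_grad():
        keypoints = model(frame, boxes)
    # Expected shape: (num_boxes, 17, 3) with (x, y, confidence) per COCO joint.
    print(keypoints.shape)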