Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,407 Bytes
475d332 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# coding: utf-8
"""
face detection and alignment using XPose
"""
import os
import pickle
import torch
import numpy as np
from PIL import Image
from torchvision.ops import nms
from collections import OrderedDict
def clean_state_dict(state_dict):
    """Strip a leading ``module.`` prefix from every key of a checkpoint state dict.

    That prefix is what ``torch.nn.DataParallel``-style wrappers typically add,
    so removing it lets the weights load into the bare model. Keys without the
    prefix are copied unchanged and insertion order is preserved; if two keys
    collapse to the same name, the later value wins.

    Args:
        state_dict: Mapping of parameter names (str) to values.

    Returns:
        A new ``OrderedDict`` with the cleaned keys; the input is not modified.
    """
    prefix = 'module.'
    return OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, value)
        for name, value in state_dict.items()
    )
from src.models.XPose import transforms as T
from src.models.XPose.models import build_model
from src.models.XPose.predefined_keypoints import *
from src.models.XPose.util import box_ops
from src.models.XPose.util.config import Config
class XPoseRunner(object):
    """Detect instances and regress their keypoints with an XPose model.

    Text embeddings for the instance / keypoint prompts are not computed at
    runtime: they are read from two pickle caches,
    ``{embeddings_cache_path}_9.pkl`` and ``{embeddings_cache_path}_68.pkl``,
    one per supported keypoint count (9 or 68).
    """

    # Number of keypoint-text slots fed to the model; the real keypoint
    # embeddings are zero-padded up to this count.
    _MAX_KEYPOINTS = 100

    def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
        """Build the model and load the cached text embeddings.

        Args:
            model_config_path: Config file passed to ``Config.fromfile``.
            model_checkpoint_path: Checkpoint loadable by ``torch.load`` that
                holds the weights under its ``"model"`` key.
            embeddings_cache_path: Path prefix of the two embedding pickles
                (``<prefix>_9.pkl`` and ``<prefix>_68.pkl``). Required.
            cpu_only: Run on ``"cpu"`` instead of ``"cuda:<device_id>"``.
            **kwargs: ``device_id`` (default 0) and
                ``flag_use_half_precision`` (default True, enables float16
                autocast during inference).

        Raises:
            ValueError: If ``embeddings_cache_path`` is None or either cache
                file cannot be opened / unpickled.
        """
        if embeddings_cache_path is None:
            # Without this check the f-strings below would look for "None_9.pkl".
            raise ValueError("embeddings_cache_path is required to load the cached clip embeddings.")

        self.device_id = kwargs.get("device_id", 0)
        self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
        self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
        self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # embeddings_cache_path at trusted files.
        try:
            with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
                self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
            with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
                self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
        except Exception as err:
            # Broad on purpose: missing file, corrupt pickle, wrong tuple
            # arity, ... all map to one user-facing error, but the root cause
            # stays attached for debugging.
            raise ValueError("Could not load clip embeddings from file, please check your file path.") from err
        print("Loaded cached embeddings from file.")

    def load_animal_model(self, model_config_path, model_checkpoint_path, device):
        """Build the XPose model from its config and load the checkpoint weights.

        Checkpoint tensors are deserialized onto the CPU; ``get_unipose_output``
        moves the model to ``self.device`` before inference.

        Args:
            model_config_path: Config file passed to ``Config.fromfile``.
            model_checkpoint_path: Checkpoint file with a ``"model"`` state dict.
            device: Device string stored on the config as ``args.device``.

        Returns:
            The model in eval mode.
        """
        args = Config.fromfile(model_config_path)
        args.device = device
        model = build_model(args)
        # Load every tensor onto the CPU regardless of where it was saved.
        checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
        # strict=False: missing / unexpected keys are tolerated and not reported.
        model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        model.eval()
        return model

    def load_image(self, input_image):
        """Convert a PIL image to RGB and to the normalized tensor the model consumes.

        The transform resizes with ``RandomResize([800], max_size=1333)`` (a
        single size, so deterministic) and normalizes with the standard
        ImageNet mean/std values.

        Returns:
            ``(image_pil, image)``: the RGB PIL image and its transformed tensor.
        """
        image_pil = input_image.convert("RGB")
        transform = T.Compose([
            T.RandomResize([800], max_size=1333),  # NOTE: fixed size to 800
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        image, _ = transform(image_pil, None)
        return image_pil, image

    def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
        """Run the model on one image and return the filtered detections.

        Args:
            image: Tensor produced by ``load_image`` (no batch dimension).
            instance_text_prompt: Comma-separated instance names.
            keypoint_text_prompt: Keypoint names; must have exactly 9 or 68
                entries (the two cached embedding sets).
            box_threshold: A query is kept only if its best sigmoid class
                score is strictly greater than this value.
            IoU_threshold: IoU threshold passed to NMS.

        Returns:
            ``(boxes, keypoints)`` CPU tensors for the surviving detections,
            ordered by decreasing score. ``boxes`` are in (cx, cy, w, h) form;
            ``keypoints`` has ``2 * len(keypoint_text_prompt)`` columns.

        Raises:
            ValueError: If ``keypoint_text_prompt`` has neither 9 nor 68 entries.
        """
        num_kpts = len(keypoint_text_prompt)
        if num_kpts == 9:
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
        elif num_kpts == 68:
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
        else:
            raise ValueError("Invalid number of keypoint embeddings.")

        # The cached embeddings may have been pickled on another device; align
        # them with the padding tensors below (torch.cat needs one device).
        ins_text_embeddings = ins_text_embeddings.to(self.device).float()
        kpt_text_embeddings = kpt_text_embeddings.to(self.device).float()
        num_real = kpt_text_embeddings.shape[0]
        num_pad = self._MAX_KEYPOINTS - num_real
        embed_dim = kpt_text_embeddings.shape[1]

        target = {
            "instance_text_prompt": instance_text_prompt.split(','),
            "keypoint_text_prompt": keypoint_text_prompt,
            "object_embeddings_text": ins_text_embeddings,
            "kpts_embeddings_text": torch.cat(
                (kpt_text_embeddings, torch.zeros(num_pad, embed_dim, device=self.device)),
                dim=0),
            # 1 marks a real keypoint slot, 0 marks padding.
            "kpt_vis_text": torch.cat(
                (torch.ones(num_real, device=self.device),
                 torch.zeros(num_pad, device=self.device)),
                dim=0),
        }

        # A no-op after the first call; keeps the model on self.device.
        self.model = self.model.to(self.device)
        image = image.to(self.device)
        device_type = torch.device(self.device).type  # "cuda" or "cpu"
        with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.float16,
                                             enabled=self.flag_use_half_precision):
            outputs = self.model(image[None], [target])

        logits = outputs["pred_logits"].sigmoid()[0].cpu()
        boxes = outputs["pred_boxes"][0].cpu()
        keypoints = outputs["pred_keypoints"][0][:, :2 * num_kpts].cpu()

        # Keep queries whose best class score clears the threshold ...
        scores = logits.max(dim=1)[0]
        keep_mask = scores > box_threshold
        boxes, keypoints, scores = boxes[keep_mask], keypoints[keep_mask], scores[keep_mask]
        # ... then de-duplicate with NMS, which expects (x1, y1, x2, y2) boxes.
        keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes), scores, iou_threshold=IoU_threshold)
        return boxes[keep_indices], keypoints[keep_indices]

    @staticmethod
    def _lookup_keypoint_dict(keypoint_text_example, instance_text_prompt):
        """Return the first module-level keypoint definition matching a name, else ``animal``."""
        module_globals = globals()
        for name in (keypoint_text_example, instance_text_prompt):
            candidate = module_globals.get(name)
            # Skip names that resolve to unrelated globals (modules, classes, ...).
            if isinstance(candidate, dict) and "keypoints" in candidate:
                return candidate
        return module_globals["animal"]

    def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
        """Detect the best-scoring instance and return its keypoints in pixel units.

        The keypoint definition is looked up among this module's globals
        (populated by the ``predefined_keypoints`` wildcard import). The lookup
        tries ``keypoint_text_example``, then ``instance_text_prompt``, and
        falls back to ``animal``.

        Args:
            input_image: PIL image.
            instance_text_prompt: Comma-separated instance names.
            keypoint_text_example: Name of a predefined keypoint set.
            box_threshold: Score threshold forwarded to ``get_unipose_output``.
            IoU_threshold: NMS IoU threshold forwarded to ``get_unipose_output``.

        Returns:
            ``np.ndarray`` of shape ``(num_kpts, 2)`` with one (x, y) row per
            keypoint: the model's x/y values multiplied by the image
            width/height.

        Raises:
            IndexError: If no detection survives thresholding and NMS.
        """
        keypoint_dict = self._lookup_keypoint_dict(keypoint_text_example, instance_text_prompt)
        keypoint_text_prompt = keypoint_dict.get("keypoints")

        image_pil, image = self.load_image(input_image)
        _, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt,
                                                    box_threshold, IoU_threshold)
        if keypoints_filt.shape[0] == 0:
            # IndexError keeps the exception type that indexing the empty
            # result used to raise, with a message that explains why.
            raise IndexError(
                f"XPose found no '{instance_text_prompt}' instance with score above "
                f"box_threshold={box_threshold}.")

        W, H = image_pil.size
        num_kpts = len(keypoint_text_prompt)
        # NMS output is score-sorted, so row 0 is the best detection. Its
        # keypoints are stored flat as (x0, y0, x1, y1, ...).
        kp = keypoints_filt[0].cpu().numpy()[:num_kpts * 2]
        return kp.reshape(num_kpts, 2) * np.array([W, H])

    def warmup(self):
        """Run one inference on a blank 512x512 RGB image.

        Presumably this is meant to pay one-off initialization costs before
        real requests. ``box_threshold=0.0`` keeps every query with a positive
        score, so a detection is normally available.
        """
        img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
        self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
|