# coding: utf-8
"""
face detection and alignment using XPose
"""
import os
import pickle
import torch
import numpy as np
from PIL import Image
from torchvision.ops import nms
from collections import OrderedDict

from src.models.XPose import transforms as T
from src.models.XPose.models import build_model
from src.models.XPose.predefined_keypoints import *
from src.models.XPose.util import box_ops
from src.models.XPose.util.config import Config


def clean_state_dict(state_dict):
    """Remove a leading 'module.' prefix from checkpoint keys."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict
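
# Illustrative example (not part of the original logic): checkpoints saved from a
# torch.nn.DataParallel-wrapped model prefix every key with 'module.', e.g.
#   {'module.backbone.0.weight': tensor, ...}
# clean_state_dict strips that prefix so the weights load into a plain model:
#   {'backbone.0.weight': tensor, ...}
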
class XPoseRunner(object):
    def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
        self.device_id = kwargs.get("device_id", 0)
        self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
        self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
        self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)
        # Load cached CLIP text embeddings for the 9- and 68-keypoint prompts
        try:
            with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
                self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
            with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
                self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
            print("Loaded cached embeddings from file.")
        except Exception as e:
            raise ValueError("Could not load CLIP embeddings from file, please check your file path.") from e
    def load_animal_model(self, model_config_path, model_checkpoint_path, device):
        args = Config.fromfile(model_config_path)
        args.device = device
        model = build_model(args)
        checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
        load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        model.eval()
        return model
    def load_image(self, input_image):
        image_pil = input_image.convert("RGB")
        transform = T.Compose([
            T.RandomResize([800], max_size=1333),  # NOTE: fixed size to 800
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        image, _ = transform(image_pil, None)
        return image_pil, image
    def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
        instance_list = instance_text_prompt.split(',')

        if len(keypoint_text_prompt) == 9:
            # torch.Size([1, 512]) torch.Size([9, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
        elif len(keypoint_text_prompt) == 68:
            # torch.Size([1, 512]) torch.Size([68, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
        else:
            raise ValueError("Invalid number of keypoint embeddings.")

        # Pad keypoint embeddings and visibility mask to the model's fixed budget of 100 keypoints
        target = {
            "instance_text_prompt": instance_list,
            "keypoint_text_prompt": keypoint_text_prompt,
            "object_embeddings_text": ins_text_embeddings.float(),
            "kpts_embeddings_text": torch.cat(
                (kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)),
                dim=0),
            "kpt_vis_text": torch.cat(
                (torch.ones(kpt_text_embeddings.shape[0], device=self.device),
                 torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)),
                dim=0)
        }

        self.model = self.model.to(self.device)
        image = image.to(self.device)

        with torch.no_grad():
            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
                outputs = self.model(image[None], [target])

        logits = outputs["pred_logits"].sigmoid()[0]
        boxes = outputs["pred_boxes"][0]  # normalized cxcywh
        keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]  # normalized (x, y) pairs

        logits_filt = logits.cpu().clone()
        boxes_filt = boxes.cpu().clone()
        keypoints_filt = keypoints.cpu().clone()

        # Keep only predictions whose best score exceeds the box threshold
        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
        logits_filt = logits_filt[filt_mask]
        boxes_filt = boxes_filt[filt_mask]
        keypoints_filt = keypoints_filt[filt_mask]

        # Non-maximum suppression over the surviving boxes
        keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0],
                           iou_threshold=IoU_threshold)

        filtered_boxes = boxes_filt[keep_indices]
        filtered_keypoints = keypoints_filt[keep_indices]

        return filtered_boxes, filtered_keypoints
    def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
        # Pick the predefined keypoint vocabulary (e.g. 'face') from
        # src.models.XPose.predefined_keypoints, falling back to 'animal'
        if keypoint_text_example in globals():
            keypoint_dict = globals()[keypoint_text_example]
        elif instance_text_prompt in globals():
            keypoint_dict = globals()[instance_text_prompt]
        else:
            keypoint_dict = globals()["animal"]

        keypoint_text_prompt = keypoint_dict.get("keypoints")
        keypoint_skeleton = keypoint_dict.get("skeleton")

        image_pil, image = self.load_image(input_image)
        boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt,
                                                             box_threshold, IoU_threshold)

        size = image_pil.size
        H, W = size[1], size[0]

        # Take the top detection and scale its normalized keypoints back to pixel coordinates
        keypoints_filt = keypoints_filt[0].squeeze(0)
        kp = np.array(keypoints_filt.cpu())
        num_kpts = len(keypoint_text_prompt)
        Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
        Z = Z.reshape(num_kpts * 2)
        x = Z[0::2]
        y = Z[1::2]

        return np.stack((x, y), axis=1)
    def warmup(self):
        img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
        self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
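
# Illustrative usage sketch (not part of the original module). The config,
# checkpoint, embeddings-cache and image paths below are hypothetical placeholders
# and depend on how the surrounding project lays out its assets.
if __name__ == "__main__":
    runner = XPoseRunner(
        model_config_path="path/to/xpose_config.py",      # hypothetical path
        model_checkpoint_path="path/to/xpose.pth",        # hypothetical path
        embeddings_cache_path="path/to/clip_embedding",   # expects <path>_9.pkl and <path>_68.pkl
        cpu_only=not torch.cuda.is_available(),
        flag_use_half_precision=True,
    )
    runner.warmup()
    img = Image.open("path/to/face.jpg")                  # hypothetical input image
    landmarks = runner.run(img, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
    print(landmarks.shape)  # (num_keypoints, 2) array of pixel-space (x, y) keypoints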