test_kdtalker / difpoint /src /utils /animal_landmark_runner.py
YinuoGuo27's picture
Upload 96 files
02f8487 verified
# coding: utf-8
"""
Face detection and alignment using XPose.
"""
import os
import pickle
import torch
import numpy as np
from PIL import Image
from torchvision.ops import nms
from collections import OrderedDict
def clean_state_dict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    Args:
        state_dict: Mapping of parameter names to values; only ``.items()``
            is used, so any mapping works.

    Returns:
        OrderedDict: The same values in the same iteration order. Every string
        key that starts with ``module.`` has that prefix stripped exactly once;
        all other keys (including non-string keys) are kept unchanged. If
        stripping makes two keys collide, the later entry wins.
    """
    prefix = 'module.'
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Derive the cut position from the prefix itself instead of a magic 7.
        if isinstance(key, str) and key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned
from src.models.XPose import transforms as T
from src.models.XPose.models import build_model
from src.models.XPose.predefined_keypoints import *
from src.models.XPose.util import box_ops
from src.models.XPose.util.config import Config
class XPoseRunner(object):
    """Run the XPose model on one image and return keypoints in pixel coordinates.

    The runner builds the model from a config/checkpoint pair, loads two sets
    of pre-computed text embeddings (for 9 and for 68 keypoints) from pickle
    files, and exposes :meth:`run` as the main entry point.
    """

    def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
        """Build the model and load the cached text embeddings.

        Args:
            model_config_path: Path handed to ``Config.fromfile``.
            model_checkpoint_path: Path handed to ``torch.load``.
            embeddings_cache_path: Path prefix of the embedding caches; the
                files ``<prefix>_9.pkl`` and ``<prefix>_68.pkl`` are read.
            cpu_only: When true, use ``"cpu"``; otherwise ``"cuda:<device_id>"``.
            **kwargs: ``device_id`` (default ``0``) and
                ``flag_use_half_precision`` (default ``True``) are read.

        Raises:
            ValueError: If either embedding cache cannot be opened or
                unpickled. The underlying error is attached as ``__cause__``.
        """
        self.device_id = kwargs.get("device_id", 0)
        self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
        self.device = "cpu" if cpu_only else f"cuda:{self.device_id}"
        self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)

        # Load cached embeddings if available
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # embeddings_cache_path at files from a trusted source.
        try:
            with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
                self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
            with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
                self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
        except Exception as err:
            # Chain the original error so the real cause (missing file,
            # corrupt pickle, ...) stays visible in the traceback.
            raise ValueError("Could not load clip embeddings from file, please check your file path.") from err
        print("Loaded cached embeddings from file.")

    def load_animal_model(self, model_config_path, model_checkpoint_path, device):
        """Build the model from its config and load checkpoint weights.

        Args:
            model_config_path: Path handed to ``Config.fromfile``.
            model_checkpoint_path: Checkpoint file; its ``"model"`` entry is
                used as the state dict.
            device: Device string stored on the config as ``args.device``.

        Returns:
            The model in eval mode. It is not moved to ``device`` here;
            ``get_unipose_output`` does that before inference.
        """
        args = Config.fromfile(model_config_path)
        args.device = device
        model = build_model(args)
        # The identity map_location keeps storages where they were
        # deserialized (CPU) regardless of the device they were saved from.
        # NOTE(security): torch.load unpickles the file; use trusted checkpoints only.
        checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
        # strict=False: missing/unexpected keys are tolerated rather than raising.
        model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        model.eval()
        return model

    def load_image(self, input_image):
        """Convert an image to RGB and to the normalized tensor fed to the model.

        Args:
            input_image: Object with a ``convert("RGB")`` method (a PIL image
                in ``warmup``).

        Returns:
            tuple: ``(image_pil, image)`` - the RGB image and the output of the
            transform pipeline applied to it.
        """
        image_pil = input_image.convert("RGB")
        transform = T.Compose([
            T.RandomResize([800], max_size=1333),  # NOTE: fixed size to 800
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # The project transforms take (image, target); no target is needed here.
        image, _ = transform(image_pil, None)
        return image_pil, image

    def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
        """Run the model and return the boxes/keypoints that survive filtering.

        Args:
            image: Image tensor as produced by :meth:`load_image`.
            instance_text_prompt: Comma-separated instance names.
            keypoint_text_prompt: Sequence of keypoint names; its length must
                be 9 or 68 to match one of the cached embedding sets.
            box_threshold: Predictions whose maximum logit score (after
                sigmoid) is not above this value are dropped.
            IoU_threshold: IoU threshold passed to ``nms``.

        Returns:
            tuple: ``(filtered_boxes, filtered_keypoints)`` as CPU tensors,
            indexed by the NMS result. Both may be empty.

        Raises:
            ValueError: If ``len(keypoint_text_prompt)`` is neither 9 nor 68.
        """
        instance_list = instance_text_prompt.split(',')
        num_kpts = len(keypoint_text_prompt)
        if num_kpts == 9:
            # torch.Size([1, 512]) torch.Size([9, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
        elif num_kpts == 68:
            # torch.Size([1, 512]) torch.Size([68, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
        else:
            raise ValueError("Invalid number of keypoint embeddings.")

        # Unpickled tensors keep the device they were saved from; move them to
        # the inference device so concatenation with the padding cannot hit a
        # device mismatch.
        ins_text_embeddings = ins_text_embeddings.float().to(self.device)
        kpt_text_embeddings = kpt_text_embeddings.float().to(self.device)

        # Keypoint embeddings are zero-padded to 100 rows; kpt_vis_text marks
        # which rows are real (1) and which are padding (0).
        num_embedded = kpt_text_embeddings.shape[0]
        pad_rows = 100 - num_embedded
        target = {
            "instance_text_prompt": instance_list,
            "keypoint_text_prompt": keypoint_text_prompt,
            "object_embeddings_text": ins_text_embeddings,
            "kpts_embeddings_text": torch.cat(
                (kpt_text_embeddings, torch.zeros(pad_rows, 512, device=self.device)),
                dim=0),
            "kpt_vis_text": torch.cat(
                (torch.ones(num_embedded, device=self.device),
                 torch.zeros(pad_rows, device=self.device)),
                dim=0),
        }

        self.model = self.model.to(self.device)
        image = image.to(self.device)

        # self.device is "cpu" or "cuda:<id>", so the first four characters
        # give the device type autocast expects.
        with torch.no_grad():
            with torch.autocast(device_type=self.device[:4], dtype=torch.float16,
                                enabled=self.flag_use_half_precision):
                outputs = self.model(image[None], [target])

        # Take the single batch element and keep only the coordinates of the
        # requested keypoints (two values per keypoint).
        logits = outputs["pred_logits"].sigmoid()[0].cpu()
        boxes = outputs["pred_boxes"][0].cpu()
        keypoints = outputs["pred_keypoints"][0][:, :2 * num_kpts].cpu()

        # Score each prediction by its best logit; computed once and reused
        # for both the threshold filter and NMS.
        scores = logits.max(dim=1)[0]
        keep_mask = scores > box_threshold
        boxes_filt = boxes[keep_mask]
        keypoints_filt = keypoints[keep_mask]
        scores_filt = scores[keep_mask]

        keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), scores_filt,
                           iou_threshold=IoU_threshold)

        filtered_boxes = boxes_filt[keep_indices]
        filtered_keypoints = keypoints_filt[keep_indices]
        return filtered_boxes, filtered_keypoints

    def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
        """Detect keypoints on ``input_image`` and return them in pixels.

        The keypoint definition is looked up by name among this module's
        globals (populated by the star import of the predefined keypoints):
        first ``keypoint_text_example``, then ``instance_text_prompt``, and
        finally the ``"animal"`` entry as a fallback.

        Args:
            input_image: Image accepted by :meth:`load_image`.
            instance_text_prompt: Comma-separated instance names.
            keypoint_text_example: Name of the keypoint definition to use.
            box_threshold: Score threshold, see :meth:`get_unipose_output`.
            IoU_threshold: NMS IoU threshold, see :meth:`get_unipose_output`.

        Returns:
            numpy.ndarray: Array of shape ``(num_keypoints, 2)`` holding
            ``(x, y)`` for the first detection returned after NMS.

        Raises:
            IndexError: If no detection survives thresholding and NMS.
            ValueError: If the keypoint definition has neither 9 nor 68 names.
        """
        known = globals()
        if keypoint_text_example in known:
            keypoint_dict = known[keypoint_text_example]
        elif instance_text_prompt in known:
            keypoint_dict = known[instance_text_prompt]
        else:
            keypoint_dict = known["animal"]

        keypoint_text_prompt = keypoint_dict.get("keypoints")

        image_pil, image = self.load_image(input_image)
        _, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt,
                                                    box_threshold, IoU_threshold)

        # Fail with an explicit message instead of an opaque tensor-indexing
        # error when nothing was detected (same exception type as before).
        if keypoints_filt.shape[0] == 0:
            raise IndexError(
                f"No detection passed box_threshold={box_threshold}; cannot extract keypoints."
            )

        # PIL reports size as (width, height).
        W, H = image_pil.size
        num_kpts = len(keypoint_text_prompt)

        # Values are interleaved x0, y0, x1, y1, ...; scaling x by W and y by H
        # assumes the model emits coordinates normalized to the image size
        # - TODO confirm against the model definition.
        kp = np.array(keypoints_filt[0].cpu())
        scaled = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
        return scaled.reshape(num_kpts, 2)

    def warmup(self):
        """Run one inference on a black 512x512 image to warm up the model."""
        img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
        # Zero thresholds keep every prediction, so run() always has a detection.
        self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)