# coding: utf-8

"""

Face detection and alignment using XPose

"""

import os
import pickle
import torch
import numpy as np
from PIL import Image
from torchvision.ops import nms
from collections import OrderedDict


def clean_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict


from src.models.XPose import transforms as T
from src.models.XPose.models import build_model
from src.models.XPose.predefined_keypoints import *
from src.models.XPose.util import box_ops
from src.models.XPose.util.config import Config


class XPoseRunner(object):
    def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
        self.device_id = kwargs.get("device_id", 0)
        self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
        self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
        self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)
        # Load cached CLIP text embeddings for the 9- and 68-keypoint prompts (required)
        try:
            with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
                self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
            with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
                self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
            print("Loaded cached embeddings from file.")
        except Exception as e:
            raise ValueError("Could not load CLIP embeddings from file, please check your file path.") from e

    def load_animal_model(self, model_config_path, model_checkpoint_path, device):
        args = Config.fromfile(model_config_path)
        args.device = device
        model = build_model(args)
        checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
        # strict=False tolerates key mismatches between the checkpoint and the built model
        model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        model.eval()
        return model

    def load_image(self, input_image):
        image_pil = input_image.convert("RGB")
        transform = T.Compose([
            T.RandomResize([800], max_size=1333),  # NOTE: single scale; shorter side resized to 800, longer side capped at 1333
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        image, _ = transform(image_pil, None)
        return image_pil, image

    def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
        instance_list = instance_text_prompt.split(',')

        if len(keypoint_text_prompt) == 9:
            # torch.Size([1, 512]) torch.Size([9, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
        elif len(keypoint_text_prompt) == 68:
            # torch.Size([1, 512]) torch.Size([68, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
        else:
            raise ValueError("Invalid number of keypoint embeddings.")
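        # Build the UniPose query: pad the keypoint text embeddings to the model's fixed 100-slot
        # capacity and mark the real slots in kpt_vis_text.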
        target = {
            "instance_text_prompt": instance_list,
            "keypoint_text_prompt": keypoint_text_prompt,
            "object_embeddings_text": ins_text_embeddings.float(),
            "kpts_embeddings_text": torch.cat(
                (kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)),
                dim=0),
            "kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device),
                                       torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0)
        }

        self.model = self.model.to(self.device)
        image = image.to(self.device)

        with torch.no_grad():
            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
                outputs = self.model(image[None], [target])

        logits = outputs["pred_logits"].sigmoid()[0]  # per-query detection scores
        boxes = outputs["pred_boxes"][0]  # normalized (cx, cy, w, h) boxes
        keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]  # flattened normalized (x, y) pairs

        logits_filt = logits.cpu().clone()
        boxes_filt = boxes.cpu().clone()
        keypoints_filt = keypoints.cpu().clone()
        filt_mask = logits_filt.max(dim=1)[0] > box_threshold  # keep queries whose best score exceeds box_threshold
        logits_filt = logits_filt[filt_mask]
        boxes_filt = boxes_filt[filt_mask]
        keypoints_filt = keypoints_filt[filt_mask]

        # Class-agnostic NMS: drop boxes overlapping a higher-scoring box above IoU_threshold
        keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0],
                           iou_threshold=IoU_threshold)

        filtered_boxes = boxes_filt[keep_indices]
        filtered_keypoints = keypoints_filt[keep_indices]

        return filtered_boxes, filtered_keypoints

    def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
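        # Resolve the predefined keypoint set (e.g. 'face') from predefined_keypoints via the
        # wildcard import above; fall back to the generic 'animal' definition.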
        if keypoint_text_example in globals():
            keypoint_dict = globals()[keypoint_text_example]
        elif instance_text_prompt in globals():
            keypoint_dict = globals()[instance_text_prompt]
        else:
            keypoint_dict = globals()["animal"]

        keypoint_text_prompt = keypoint_dict.get("keypoints")
        keypoint_skeleton = keypoint_dict.get("skeleton")

        image_pil, image = self.load_image(input_image)
        boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt,
                                                             box_threshold, IoU_threshold)

        size = image_pil.size
        H, W = size[1], size[0]
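        # Take the top-scoring detection and map its normalized keypoints back to pixel coordinates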
        keypoints_filt = keypoints_filt[0].squeeze(0)
        kp = np.array(keypoints_filt.cpu())
        num_kpts = len(keypoint_text_prompt)
        Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
        Z = Z.reshape(num_kpts * 2)
        x = Z[0::2]
        y = Z[1::2]
        return np.stack((x, y), axis=1)

    def warmup(self):
        img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
        self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
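

# Minimal usage sketch: the config, checkpoint, and embedding-cache paths below are placeholders and
# must point to your local XPose assets; the thresholds are illustrative, not tuned values.
if __name__ == "__main__":
    runner = XPoseRunner(
        model_config_path="assets/UniPose_SwinT.py",    # placeholder config path
        model_checkpoint_path="assets/xpose.pth",       # placeholder checkpoint path
        embeddings_cache_path="assets/clip_embedding",  # expects assets/clip_embedding_9.pkl and _68.pkl
        flag_use_half_precision=True,
    )
    runner.warmup()
    landmarks = runner.run(Image.open("example.png"), "face", "face", box_threshold=0.3, IoU_threshold=0.5)
    print(landmarks.shape)  # (num_keypoints, 2) array of pixel coordinates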