from __future__ import annotations

import os
import os.path as osp
from collections import defaultdict

import cv2
import torch
import numpy as np
import scipy.signal as signal
from progress.bar import Bar
from scipy.ndimage import gaussian_filter1d

from configs import constants as _C

from .backbone.hmr2 import hmr2
from .backbone.utils import process_image
from ...utils.imutils import flip_kp, flip_bbox

ROOT_DIR = osp.abspath(f"{__file__}/../../../../")

class FeatureExtractor(object):
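    """Per-frame image feature extractor built on an HMR2 backbone.

    Given a video (or a list of image paths) and per-subject tracking results,
    it crops each detected subject, encodes the crop into a feature vector,
    and stores initial SMPL estimates for the first frame of every track.
    """
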
    def __init__(self, device, flip_eval=False, max_batch_size=64):
        self.device = device
        self.flip_eval = flip_eval
        self.max_batch_size = max_batch_size

        # Load the pretrained HMR2 backbone checkpoint and switch it to eval mode.
        ckpt = osp.join(ROOT_DIR, 'checkpoints', 'hmr2a.ckpt')
        self.model = hmr2(ckpt).to(device).eval()

    def run(self, video, tracking_results, patch_h=256, patch_w=256):
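        """Extract per-frame backbone features for every tracked subject.

        `video` is either a path to a video file or a list of image paths.
        `tracking_results` maps each subject id to per-frame frame indices,
        bounding boxes, and keypoints; features (and their flipped variants
        when `flip_eval` is enabled) are appended in place.
        """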
        if isinstance(video, str) and osp.isfile(video):  # Single video file
            cap = cv2.VideoCapture(video)
            is_video = True
            length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        else:  # Image list
            cap = video
            is_video = False
            length = len(video)
            height, width = cv2.imread(video[0]).shape[:2]

        frame_id = 0
        bar = Bar('Feature extraction ...', fill='#', max=length)
        while True:
            # Read the next frame, or stop when the source is exhausted.
            if is_video:
                flag, img = cap.read()
                if not flag:
                    break
            else:
                if frame_id >= len(cap):
                    break
                img = cv2.imread(cap[frame_id])

            # Crop and encode every subject visible in this frame.
            for _id, val in tracking_results.items():
                if frame_id not in val['frame_id']: continue
                frame_id2 = np.where(val['frame_id'] == frame_id)[0][0]
                bbox = val['bbox'][frame_id2]
                cx, cy, scale = bbox

                # Crop the subject (BGR -> RGB), normalize, and encode features.
                norm_img, crop_img = process_image(img[..., ::-1], [cx, cy], scale, patch_h, patch_w)
                norm_img = torch.from_numpy(norm_img).unsqueeze(0).to(self.device)
                feature = self.model(norm_img, encode=True)
                tracking_results[_id]['features'].append(feature.cpu())

                if frame_id2 == 0:  # First frame of this subject: store initial predictions
                    tracking_results = self.predict_init(norm_img, tracking_results, _id, flip_eval=False)

                if self.flip_eval:
                    # Mirror the bbox, keypoints, and image for flipped test-time evaluation.
                    flipped_bbox = flip_bbox(bbox, width, height)
                    tracking_results[_id]['flipped_bbox'].append(flipped_bbox)

                    keypoints = val['keypoints'][frame_id2]
                    flipped_keypoints = flip_kp(keypoints, width)
                    tracking_results[_id]['flipped_keypoints'].append(flipped_keypoints)

                    flipped_features = self.model(torch.flip(norm_img, (3, )), encode=True)
                    tracking_results[_id]['flipped_features'].append(flipped_features.cpu())

                    if frame_id2 == 0:
                        tracking_results = self.predict_init(torch.flip(norm_img, (3, )), tracking_results, _id, flip_eval=True)

            bar.next()
            frame_id += 1

        return self.process(tracking_results)

    def predict_init(self, norm_img, tracking_results, _id, flip_eval=False):
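        """Store the model's initial SMPL estimates (global orientation, body
        pose, betas) for the first frame of subject `_id`."""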
        prefix = 'flipped_' if flip_eval else ''

        pred_global_orient, pred_body_pose, pred_betas, _ = self.model(norm_img, encode=False)
        tracking_results[_id][prefix + 'init_global_orient'] = pred_global_orient.cpu()
        tracking_results[_id][prefix + 'init_body_pose'] = pred_body_pose.cpu()
        tracking_results[_id][prefix + 'init_betas'] = pred_betas.cpu()
        return tracking_results

    def process(self, tracking_results):
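        """Collate per-frame lists into stacked tensors / arrays per subject."""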
        output = defaultdict(dict)

        for _id, results in tracking_results.items():
            for key, val in results.items():
                if isinstance(val, list):
                    # Concatenate per-frame tensors; stack numpy entries.
                    if isinstance(val[0], torch.Tensor):
                        val = torch.cat(val)
                    elif isinstance(val[0], np.ndarray):
                        val = np.array(val)
                output[_id][key] = val

        return output
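

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; this module uses relative imports,
# so it is meant to be imported, not executed directly). It assumes an
# upstream detection / tracking stage has produced `tracking_results`: a dict
# mapping subject ids to dicts of per-frame lists ('frame_id', 'bbox',
# 'keypoints', 'features', plus the 'flipped_*' variants when flip_eval is
# enabled). The video path below is a placeholder.
#
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     extractor = FeatureExtractor(device, flip_eval=True)
#     tracking_results = ...  # from the detection / tracking stage
#     features = extractor.run('examples/demo.mp4', tracking_results)
# ---------------------------------------------------------------------------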