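"""Per-subject image feature extraction.

Crops each tracked detection from a video (or image list), runs the HMR2
backbone to obtain per-frame image features, and predicts initial body pose
parameters (global orientation, body pose, shape betas) on each subject's
first frame. Optionally repeats the pass on horizontally flipped crops when
flip_eval is enabled.
"""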
from __future__ import annotations

import os
import os.path as osp
from collections import defaultdict

import cv2
import torch
import numpy as np
import scipy.signal as signal
from progress.bar import Bar
from scipy.ndimage import gaussian_filter1d

from configs import constants as _C

from .backbone.hmr2 import hmr2
from .backbone.utils import process_image
from ...utils.imutils import flip_kp, flip_bbox

ROOT_DIR = osp.abspath(f"{__file__}/../../../../")

class FeatureExtractor(object):
    def __init__(self, device, flip_eval=False, max_batch_size=64):
        self.device = device
        self.flip_eval = flip_eval
        self.max_batch_size = max_batch_size

        # Load the HMR2 backbone checkpoint and put it in evaluation mode.
        ckpt = osp.join(ROOT_DIR, 'checkpoints', 'hmr2a.ckpt')
        self.model = hmr2(ckpt).to(device).eval()
    @torch.no_grad()  # inference only; no gradients are needed for feature extraction
    def run(self, video, tracking_results, patch_h=256, patch_w=256):
        """Extract per-frame backbone features (and first-frame initial predictions) for every track.

        `video` is either a path to a video file or a list of image paths.
        `tracking_results` maps each track id to per-frame arrays ('frame_id',
        'bbox', 'keypoints') and lists that are appended to here ('features'
        and the 'flipped_*' variants when flip_eval is enabled).
        """
        if osp.isfile(video):
            cap = cv2.VideoCapture(video)
            is_video = True
            length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        else:  # Image list
            cap = video
            is_video = False
            length = len(video)
            height, width = cv2.imread(video[0]).shape[:2]

        frame_id = 0
        bar = Bar('Feature extraction ...', fill='#', max=length)
        while True:
            # Read the next video frame, or the next image in the list.
            if is_video:
                flag, img = cap.read()
                if not flag:
                    break
            else:
                if frame_id >= len(cap):
                    break
                img = cv2.imread(cap[frame_id])

            for _id, val in tracking_results.items():
                if frame_id not in val['frame_id']:
                    continue
                # Index of the current frame within this subject's track.
                frame_id2 = np.where(val['frame_id'] == frame_id)[0][0]
                bbox = val['bbox'][frame_id2]
                cx, cy, scale = bbox

                # Crop and normalize the person patch (BGR -> RGB), then encode it.
                norm_img, crop_img = process_image(img[..., ::-1], [cx, cy], scale, patch_h, patch_w)
                norm_img = torch.from_numpy(norm_img).unsqueeze(0).to(self.device)
                feature = self.model(norm_img, encode=True)
                tracking_results[_id]['features'].append(feature.cpu())

                if frame_id2 == 0:  # First frame of this subject
                    tracking_results = self.predict_init(norm_img, tracking_results, _id, flip_eval=False)

                if self.flip_eval:
                    # Repeat the pass on the horizontally flipped crop.
                    flipped_bbox = flip_bbox(bbox, width, height)
                    tracking_results[_id]['flipped_bbox'].append(flipped_bbox)

                    keypoints = val['keypoints'][frame_id2]
                    flipped_keypoints = flip_kp(keypoints, width)
                    tracking_results[_id]['flipped_keypoints'].append(flipped_keypoints)

                    flipped_features = self.model(torch.flip(norm_img, (3, )), encode=True)
                    tracking_results[_id]['flipped_features'].append(flipped_features.cpu())

                    if frame_id2 == 0:
                        tracking_results = self.predict_init(torch.flip(norm_img, (3, )), tracking_results, _id, flip_eval=True)

            bar.next()
            frame_id += 1

        return self.process(tracking_results)
    def predict_init(self, norm_img, tracking_results, _id, flip_eval=False):
        """Store the model's initial pose and shape predictions for a subject's first frame."""
        prefix = 'flipped_' if flip_eval else ''

        pred_global_orient, pred_body_pose, pred_betas, _ = self.model(norm_img, encode=False)
        tracking_results[_id][prefix + 'init_global_orient'] = pred_global_orient.cpu()
        tracking_results[_id][prefix + 'init_body_pose'] = pred_body_pose.cpu()
        tracking_results[_id][prefix + 'init_betas'] = pred_betas.cpu()
        return tracking_results
    def process(self, tracking_results):
        """Collate the per-frame lists into stacked tensors/arrays for each subject."""
        output = defaultdict(dict)

        for _id, results in tracking_results.items():
            for key, val in results.items():
                if isinstance(val, list):
                    if isinstance(val[0], torch.Tensor):
                        val = torch.cat(val)
                    elif isinstance(val[0], np.ndarray):
                        val = np.array(val)
                output[_id][key] = val

        return output
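

if __name__ == '__main__':
    # Minimal usage sketch, not part of the pipeline: it only illustrates the
    # tracking_results layout that run() expects, as inferred from the code
    # above. The video path, device choice, and the single hand-written track
    # below are placeholders; in practice this dictionary comes from a tracker.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    extractor = FeatureExtractor(device, flip_eval=False)

    tracking_results = {
        0: {                                               # one track id
            'frame_id': np.array([0, 1, 2]),               # frames where the subject is visible
            'bbox': np.array([[128.0, 128.0, 1.0]] * 3),   # (cx, cy, scale) per frame
            'keypoints': np.zeros((3, 17, 3)),             # per-frame 2D keypoints (layout assumed)
            'features': [],                                 # appended to by run()
        },
    }

    output = extractor.run('/path/to/video.mp4', tracking_results)  # placeholder path
    print(output[0]['features'].shape)  # per-frame features, stacked by process()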