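"""Preprocess the 3DPW dataset into per-sequence, per-person records for the val/test splits.

The script combines the raw 3DPW sequence files with TCMR-preprocessed annotations and
precomputed 2D detections, extracts image features (and horizontally flipped features) with
the ViT backbone of FeatureExtractor, computes an initial SMPL prediction from the first
frame of each clip, and saves everything with joblib.
"""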
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os.path as osp
from glob import glob
from collections import defaultdict

import cv2
import torch
import pickle
import joblib
import argparse
import numpy as np
from loguru import logger
from progress.bar import Bar

from configs import constants as _C
from lib.models.smpl import SMPL
from lib.models.preproc.extractor import FeatureExtractor
from lib.models.preproc.backbone.utils import process_image
from lib.utils import transforms
from lib.utils.imutils import (
    flip_kp, flip_bbox
)
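
# Accumulated output database plus fixed input locations: precomputed 2D detections and the
# TCMR-preprocessed 3DPW annotation file (the 'dset' token is replaced with the split name).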
dataset = defaultdict(list)
detection_results_dir = 'dataset/detection_results/3DPW'
tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_dset_db.pt'

def preprocess(dset, batch_size):
    if dset == 'val': _dset = 'validation'
    else: _dset = dset

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    save_pth = osp.join(_C.PATHS.PARSED_DATA, f'3pdw_{dset}_vit.pth')    # Use ViT feature extractor

    extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)
    tcmr_data = joblib.load(tcmr_annot_pth.replace('dset', dset))
    smpl_neutral = SMPL(model_path=_C.BMODEL.FLDR)
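
    # Recover the sequence names in their original order: np.unique sorts alphabetically,
    # so re-sort by first-occurrence index, then map each name to its sequence .pkl file.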
    annot_file_list, idxs = np.unique(tcmr_data['vid_name'], return_index=True)
    idxs = idxs.tolist()
    annot_file_list = [annot_file_list[idxs.index(idx)] for idx in sorted(idxs)]
    annot_file_list = [osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', _dset, annot_file[:-2] + '.pkl') for annot_file in annot_file_list]
    annot_file_list = list(dict.fromkeys(annot_file_list))

    for annot_file in annot_file_list:
        seq = annot_file.split('/')[-1].split('.')[0]

        data = pickle.load(open(annot_file, 'rb'), encoding='latin1')

        num_people = len(data['poses'])
        num_frames = len(data['img_frame_ids'])
        assert (data['poses2d'][0].shape[0] == num_frames)

        K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float()

        for p_id in range(num_people):
            logger.info(f'==> {seq} {p_id}')

            gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]]

            # ======== Add TCMR data ======== #
            vid_name = f'{seq}_{p_id}'
            tcmr_ids = [i for i, v in enumerate(tcmr_data['vid_name']) if vid_name in v]
            frame_ids = tcmr_data['frame_id'][tcmr_ids]
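            # frame_ids are the frames kept by the TCMR preprocessing for this person; they
            # index both the raw 3DPW annotations and the precomputed detections below.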
            pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids]
            shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1)
            pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float()    # Camera coordinate
            cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float()

            # ======== Get detection results ======== #
            fname = f'{seq}_{p_id}.npy'
            pred_kp2d = torch.from_numpy(
                np.load(osp.join(detection_results_dir, fname))
            ).float()[frame_ids]
            # ======== Get detection results ======== #

            img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg')))
            img_paths = [img_path for i, img_path in enumerate(img_paths) if i in frame_ids]
            img = cv2.imread(img_paths[0]); res_h, res_w = img.shape[:2]
            vid_idxs = fname.split('.')[0]

            # ======== Append data ======== #
            dataset['gender'].append(gender)
            dataset['vid'].append(vid_idxs)
            dataset['pose'].append(pose)
            dataset['betas'].append(shape)
            dataset['cam_poses'].append(cam_poses)
            dataset['frame_id'].append(torch.from_numpy(frame_ids))
            dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float())
            dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float())
            dataset['kp2d'].append(pred_kp2d)

            # Flipped data
            dataset['flipped_bbox'].append(
                torch.from_numpy(flip_bbox(dataset['bbox'][-1].clone().numpy(), res_w, res_h)).float()
            )
            dataset['flipped_kp2d'].append(
                torch.from_numpy(flip_kp(dataset['kp2d'][-1].clone().numpy(), res_w)).float()
            )
            # ======== Append data ======== #

            # ======== Extract features ======== #
            patch_list = []
            bboxes = dataset['bbox'][-1].clone().numpy()
            bar = Bar(f'Load images', fill='#', max=len(img_paths))

            for img_path, bbox in zip(img_paths, bboxes):
                img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
                norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
                patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
                bar.next()

            patch_list = torch.split(torch.cat(patch_list), batch_size)
            features, flipped_features = [], []

            for i, patch in enumerate(patch_list):
                feature = extractor.model(patch.cuda(), encode=True)
                features.append(feature.cpu())

                flipped_feature = extractor.model(torch.flip(patch, (3, )).cuda(), encode=True)
                flipped_features.append(flipped_feature.cpu())
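
                # Keep the very first crop; it is reused below to compute the initial SMPL prediction.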
                if i == 0:
                    init_patch = patch[[0]].clone()

            features = torch.cat(features)
            flipped_features = torch.cat(flipped_features)
            dataset['features'].append(features)
            dataset['flipped_features'].append(flipped_features)
            # ======== Extract features ======== #

            # Pad 1 frame
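            # (duplicate the first frame of every tensor just appended, so each clip carries
            # an extra leading frame)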
            for key, val in dataset.items():
                if isinstance(val[-1], torch.Tensor):
                    dataset[key][-1] = torch.cat((val[-1][:1].clone(), val[-1][:]), dim=0)

            # Initial predictions
            bbox = torch.from_numpy(bboxes[:1].copy()).float().cuda()
            bbox_center = bbox[:, :2].clone()
            bbox_scale = bbox[:, 2].clone() / 200

            kwargs = {'img_w': torch.tensor(res_w).repeat(1).float().cuda(),
                      'img_h': torch.tensor(res_h).repeat(1).float().cuda(),
                      'bbox_center': bbox_center, 'bbox_scale': bbox_scale}

            pred_global_orient, pred_pose, pred_shape, _ = extractor.model(init_patch.cuda(), **kwargs)
            pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(),
                                                  body_pose=pred_pose.cpu(),
                                                  betas=pred_shape.cpu(),
                                                  pose2rot=False)
            init_kp3d = pred_output.joints
            init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))

            dataset['init_kp3d'].append(init_kp3d)
            dataset['init_pose'].append(init_pose.cpu())

            # Flipped initial predictions
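            # Mirror the bbox x-center in place so the shared kwargs match the horizontally flipped crop.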
            bbox_center[:, 0] = res_w - bbox_center[:, 0]
            pred_global_orient, pred_pose, pred_shape, _ = extractor.model(torch.flip(init_patch, (3, )).cuda(), **kwargs)
            pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(),
                                                  body_pose=pred_pose.cpu(),
                                                  betas=pred_shape.cpu(),
                                                  pose2rot=False)
            init_kp3d = pred_output.joints
            init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1))

            dataset['flipped_init_kp3d'].append(init_kp3d)
            dataset['flipped_init_pose'].append(init_pose.cpu())

    joblib.dump(dataset, save_pth)
    logger.info(f'\n ==> Done !')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--split', type=str, choices=['val', 'test'], help='Data split')
    parser.add_argument('-b', '--batch_size', type=int, default=128, help='Batch size')
    args = parser.parse_args()

    preprocess(args.split, args.batch_size)
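
# Example invocation (the filename preprocess_3dpw.py is hypothetical; use whatever name
# this script has in the repository):
#   python preprocess_3dpw.py --split test --batch_size 64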