import os
import argparse
import os.path as osp
from glob import glob
from collections import defaultdict

import cv2
import torch
import joblib
import numpy as np
from loguru import logger
from progress.bar import Bar

from configs.config import get_cfg_defaults
from lib.data.datasets import CustomDataset
from lib.utils.imutils import avg_preds
from lib.utils.transforms import matrix_to_axis_angle
from lib.models import build_network, build_body_model
from lib.models.preproc.detector import DetectionModel
from lib.models.preproc.extractor import FeatureExtractor
from lib.models.smplify import TemporalSMPLify

try:
    from lib.models.preproc.slam import SLAMModel
    _run_global = True
except ImportError:
    logger.info('DPVO is not properly installed. Estimating motion in local (camera) coordinates only!')
    _run_global = False


def run(cfg,
        video,
        output_pth,
        network,
        calib=None,
        run_global=True,
        save_pkl=False,
        visualize=False):
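    """Run the WHAM demo pipeline on a single video.

    Preprocessing (2D detection and tracking, optional DPVO SLAM, and image
    feature extraction) is cached under `output_pth`. WHAM is then run once per
    tracked subject, optionally refined with Temporal SMPLify, and the per-subject
    SMPL poses, translations, betas, and vertices are collected, optionally saved
    as a pickle file, and optionally visualized.
    """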
    cap = cv2.VideoCapture(video)
    assert cap.isOpened(), f'Failed to load video file {video}'
    fps = cap.get(cv2.CAP_PROP_FPS)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)

    # Estimate motion in global coordinates only if requested and DPVO is available
    run_global = run_global and _run_global

    # Preprocess
    with torch.no_grad():
        if not (osp.exists(osp.join(output_pth, 'tracking_results.pth')) and
                osp.exists(osp.join(output_pth, 'slam_results.pth'))):

            detector = DetectionModel(cfg.DEVICE.lower())
            extractor = FeatureExtractor(cfg.DEVICE.lower(), cfg.FLIP_EVAL)

            if run_global:
                slam = SLAMModel(video, output_pth, width, height, calib)
            else:
                slam = None

            bar = Bar('Preprocess: 2D detection and SLAM', fill='#', max=length)
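
            # Stream the video frame by frame: run 2D person detection and tracking on
            # every frame and, when DPVO is available, step the SLAM camera tracker.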
            while cap.isOpened():
                flag, img = cap.read()
                if not flag:
                    break

                # 2D detection and tracking
                detector.track(img, fps, length)

                # SLAM
                if slam is not None:
                    slam.track()

                bar.next()
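
            # Consolidate the per-frame detections into per-subject tracking results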
            tracking_results = detector.process(fps)

            if slam is not None:
                slam_results = slam.process()
            else:
                slam_results = np.zeros((length, 7))
                slam_results[:, 3] = 1.0    # Unit quaternion

            # Extract image features
            # TODO: Merge this into the previous while loop with online bbox smoothing.
            tracking_results = extractor.run(video, tracking_results)
            logger.info('Completed data preprocessing!')

            # Save the processed data
            joblib.dump(tracking_results, osp.join(output_pth, 'tracking_results.pth'))
            joblib.dump(slam_results, osp.join(output_pth, 'slam_results.pth'))
            logger.info(f'Saved processed data at {output_pth}')

        # If the processed data already exists, load it
        else:
            tracking_results = joblib.load(osp.join(output_pth, 'tracking_results.pth'))
            slam_results = joblib.load(osp.join(output_pth, 'slam_results.pth'))
            logger.info(f'Processed data already exists at {output_pth}. Loading it.')
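
    # tracking_results: per-subject 2D tracking data (keypoints and image features);
    # slam_results: per-frame camera motion from SLAM, or the static-camera fallback.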
    # Build dataset
    dataset = CustomDataset(cfg, tracking_results, slam_results, width, height, fps)

    # Run WHAM
    results = defaultdict(dict)

    n_subjs = len(dataset)
    for subj in range(n_subjs):
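        # With FLIP_EVAL, WHAM runs on both the original and the horizontally flipped
        # input, and the two predictions (pose, shape, contact) are averaged.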
        with torch.no_grad():
            if cfg.FLIP_EVAL:
                # Forward pass with flipped input
                flipped_batch = dataset.load_data(subj, True)
                _id, x, inits, features, mask, init_root, cam_angvel, frame_id, kwargs = flipped_batch
                flipped_pred = network(x, inits, features, mask=mask, init_root=init_root, cam_angvel=cam_angvel, return_y_up=True, **kwargs)

                # Forward pass with normal input
                batch = dataset.load_data(subj)
                _id, x, inits, features, mask, init_root, cam_angvel, frame_id, kwargs = batch
                pred = network(x, inits, features, mask=mask, init_root=init_root, cam_angvel=cam_angvel, return_y_up=True, **kwargs)

                # Merge the two predictions
                flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
                pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
                flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
                avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
                avg_pose = avg_pose.reshape(-1, 144)
                avg_contact = (flipped_pred['contact'][..., [2, 3, 0, 1]] + pred['contact']) / 2

                # Refine trajectory with the merged prediction
                network.pred_pose = avg_pose.view_as(network.pred_pose)
                network.pred_shape = avg_shape.view_as(network.pred_shape)
                network.pred_contact = avg_contact.view_as(network.pred_contact)
                output = network.forward_smpl(**kwargs)
                pred = network.refine_trajectory(output, cam_angvel, return_y_up=True)

            else:
                # data
                batch = dataset.load_data(subj)
                _id, x, inits, features, mask, init_root, cam_angvel, frame_id, kwargs = batch

                # inference
                pred = network(x, inits, features, mask=mask, init_root=init_root, cam_angvel=cam_angvel, return_y_up=True, **kwargs)
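
        # Optional post-processing: fit SMPL to the tracked 2D keypoints with Temporal
        # SMPLify, then re-run the SMPL layer and trajectory refinement with the result.
        # Note: relies on `args` and `smpl` defined in the __main__ block below.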
        if args.run_smplify:
            smplify = TemporalSMPLify(smpl, img_w=width, img_h=height, device=cfg.DEVICE)
            input_keypoints = dataset.tracking_results[_id]['keypoints']
            pred = smplify.fit(pred, input_keypoints, **kwargs)

            with torch.no_grad():
                network.pred_pose = pred['pose']
                network.pred_shape = pred['betas']
                network.pred_cam = pred['cam']
                output = network.forward_smpl(**kwargs)
                pred = network.refine_trajectory(output, cam_angvel, return_y_up=True)

        # ========= Store results ========= #
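        # Rotation matrices are converted to axis-angle: a 69-D body pose (23 joints x 3)
        # stacked with the 3-D root orientation, in both camera and world coordinates.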
        pred_body_pose = matrix_to_axis_angle(pred['poses_body']).cpu().numpy().reshape(-1, 69)
        pred_root = matrix_to_axis_angle(pred['poses_root_cam']).cpu().numpy().reshape(-1, 3)
        pred_root_world = matrix_to_axis_angle(pred['poses_root_world']).cpu().numpy().reshape(-1, 3)
        pred_pose = np.concatenate((pred_root, pred_body_pose), axis=-1)
        pred_pose_world = np.concatenate((pred_root_world, pred_body_pose), axis=-1)
        pred_trans = (pred['trans_cam'] - network.output.offset).cpu().numpy()

        results[_id]['pose'] = pred_pose
        results[_id]['trans'] = pred_trans
        results[_id]['pose_world'] = pred_pose_world
        results[_id]['trans_world'] = pred['trans_world'].cpu().squeeze(0).numpy()
        results[_id]['betas'] = pred['betas'].cpu().squeeze(0).numpy()
        results[_id]['verts'] = (pred['verts_cam'] + pred['trans_cam'].unsqueeze(1)).cpu().numpy()
        results[_id]['frame_ids'] = frame_id

    if save_pkl:
        joblib.dump(results, osp.join(output_pth, 'wham_output.pkl'))

    # Visualize
    if visualize:
        from lib.vis.run_vis import run_vis_on_demo
        with torch.no_grad():
            run_vis_on_demo(cfg, video, results, output_pth, network.smpl, vis_global=run_global)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--video', type=str,
                        default='examples/demo_video.mp4',
                        help='Input video path or YouTube link')

    parser.add_argument('--output_pth', type=str, default='output/demo',
                        help='Output folder to write results')

    parser.add_argument('--calib', type=str, default=None,
                        help='Camera calibration file path')

    parser.add_argument('--estimate_local_only', action='store_true',
                        help='Only estimate motion in camera coordinates if True')

    parser.add_argument('--visualize', action='store_true',
                        help='Visualize the output mesh if True')

    parser.add_argument('--save_pkl', action='store_true',
                        help='Save output as a pkl file')

    parser.add_argument('--run_smplify', action='store_true',
                        help='Run Temporal SMPLify for post-processing')

    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file('configs/yamls/demo.yaml')

    logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
    logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')

    # ========= Load WHAM ========= #
    smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
    smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
    network = build_network(cfg, smpl)
    network.eval()

    # Output folder
    sequence = '.'.join(args.video.split('/')[-1].split('.')[:-1])
    output_pth = osp.join(args.output_pth, sequence)
    os.makedirs(output_pth, exist_ok=True)

    run(cfg,
        args.video,
        output_pth,
        network,
        args.calib,
        run_global=not args.estimate_local_only,
        save_pkl=args.save_pkl,
        visualize=args.visualize)

    print()
    logger.info('Done!')