import os
import os.path as osp
import argparse
from collections import defaultdict

import cv2
import torch
import joblib
import numpy as np
from loguru import logger
from progress.bar import Bar

from configs.config import get_cfg_defaults
from lib.data.datasets import CustomDataset
from lib.utils.imutils import avg_preds
from lib.utils.transforms import matrix_to_axis_angle
from lib.models import build_network, build_body_model
from lib.models.preproc.detector import DetectionModel
from lib.models.preproc.extractor import FeatureExtractor
from lib.models.smplify import TemporalSMPLify

try:
    from lib.models.preproc.slam import SLAMModel
    _run_global = True
except ImportError:
    logger.info('DPVO is not properly installed. Only estimating in local coordinates!')
    _run_global = False


def run(cfg, video, output_pth, network, calib=None,
        run_global=True, run_smplify=False, save_pkl=False, visualize=False):

    cap = cv2.VideoCapture(video)
    assert cap.isOpened(), f'Failed to load video file {video}'
    fps = cap.get(cv2.CAP_PROP_FPS)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)

    # Whether or not to estimate motion in global coordinates
    run_global = run_global and _run_global

    # Preprocess: 2D detection/tracking, SLAM, and image feature extraction
    with torch.no_grad():
        if not (osp.exists(osp.join(output_pth, 'tracking_results.pth')) and
                osp.exists(osp.join(output_pth, 'slam_results.pth'))):

            detector = DetectionModel(cfg.DEVICE.lower())
            extractor = FeatureExtractor(cfg.DEVICE.lower(), cfg.FLIP_EVAL)

            if run_global:
                slam = SLAMModel(video, output_pth, width, height, calib)
            else:
                slam = None

            bar = Bar('Preprocess: 2D detection and SLAM', fill='#', max=length)
            while cap.isOpened():
                flag, img = cap.read()
                if not flag:
                    break

                # 2D detection and tracking
                detector.track(img, fps, length)

                # SLAM
                if slam is not None:
                    slam.track()

                bar.next()

            tracking_results = detector.process(fps)

            if slam is not None:
                slam_results = slam.process()
            else:
                slam_results = np.zeros((length, 7))
                slam_results[:, 3] = 1.0    # Unit quaternion

            # Extract image features
            # TODO: Merge this into the previous while loop with an online bbox smoothing.
            tracking_results = extractor.run(video, tracking_results)
            logger.info('Completed data preprocessing!')

            # Save the processed data
            joblib.dump(tracking_results, osp.join(output_pth, 'tracking_results.pth'))
            joblib.dump(slam_results, osp.join(output_pth, 'slam_results.pth'))
            logger.info(f'Saved processed data at {output_pth}')

        # If the processed data already exists, load it
        else:
            tracking_results = joblib.load(osp.join(output_pth, 'tracking_results.pth'))
            slam_results = joblib.load(osp.join(output_pth, 'slam_results.pth'))
            logger.info(f'Processed data already exists at {output_pth}. Loading it.')

    # Build dataset
    dataset = CustomDataset(cfg, tracking_results, slam_results, width, height, fps)

    # Run WHAM
    results = defaultdict(dict)

    n_subjs = len(dataset)
    for subj in range(n_subjs):

        with torch.no_grad():
            if cfg.FLIP_EVAL:
                # Forward pass with flipped input
                flipped_batch = dataset.load_data(subj, True)
                _id, x, inits, features, mask, init_root, cam_angvel, frame_id, kwargs = flipped_batch
                flipped_pred = network(x, inits, features, mask=mask, init_root=init_root,
                                       cam_angvel=cam_angvel, return_y_up=True, **kwargs)

                # Forward pass with normal input
                batch = dataset.load_data(subj)
                _id, x, inits, features, mask, init_root, cam_angvel, frame_id, kwargs = batch
                pred = network(x, inits, features, mask=mask, init_root=init_root,
                               cam_angvel=cam_angvel, return_y_up=True, **kwargs)

                # Merge the two predictions
                flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0)
                pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0)
                flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6)
                avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape)
                avg_pose = avg_pose.reshape(-1, 144)
                avg_contact = (flipped_pred['contact'][..., [2, 3, 0, 1]] + pred['contact']) / 2

                # Refine trajectory with the merged prediction
                network.pred_pose = avg_pose.view_as(network.pred_pose)
                network.pred_shape = avg_shape.view_as(network.pred_shape)
                network.pred_contact = avg_contact.view_as(network.pred_contact)
                output = network.forward_smpl(**kwargs)
                pred = network.refine_trajectory(output, cam_angvel, return_y_up=True)

            else:
                # Data
                batch = dataset.load_data(subj)
                _id, x, inits, features, mask, init_root, cam_angvel, frame_id, kwargs = batch

                # Inference
                pred = network(x, inits, features, mask=mask, init_root=init_root,
                               cam_angvel=cam_angvel, return_y_up=True, **kwargs)

        # Optional Temporal SMPLify post-processing
        if run_smplify:
            smplify = TemporalSMPLify(network.smpl, img_w=width, img_h=height, device=cfg.DEVICE)
            input_keypoints = dataset.tracking_results[_id]['keypoints']
            pred = smplify.fit(pred, input_keypoints, **kwargs)

            with torch.no_grad():
                network.pred_pose = pred['pose']
                network.pred_shape = pred['betas']
                network.pred_cam = pred['cam']
                output = network.forward_smpl(**kwargs)
                pred = network.refine_trajectory(output, cam_angvel, return_y_up=True)

        # ========= Store results ========= #
        pred_body_pose = matrix_to_axis_angle(pred['poses_body']).cpu().numpy().reshape(-1, 69)
        pred_root = matrix_to_axis_angle(pred['poses_root_cam']).cpu().numpy().reshape(-1, 3)
        pred_root_world = matrix_to_axis_angle(pred['poses_root_world']).cpu().numpy().reshape(-1, 3)
        pred_pose = np.concatenate((pred_root, pred_body_pose), axis=-1)
        pred_pose_world = np.concatenate((pred_root_world, pred_body_pose), axis=-1)
        pred_trans = (pred['trans_cam'] - network.output.offset).cpu().numpy()

        results[_id]['pose'] = pred_pose
        results[_id]['trans'] = pred_trans
        results[_id]['pose_world'] = pred_pose_world
        results[_id]['trans_world'] = pred['trans_world'].cpu().squeeze(0).numpy()
        results[_id]['betas'] = pred['betas'].cpu().squeeze(0).numpy()
        results[_id]['verts'] = (pred['verts_cam'] + pred['trans_cam'].unsqueeze(1)).cpu().numpy()
        results[_id]['frame_ids'] = frame_id

    if save_pkl:
        joblib.dump(results, osp.join(output_pth, 'wham_output.pkl'))

    # Visualize
    if visualize:
        from lib.vis.run_vis import run_vis_on_demo
        with torch.no_grad():
            run_vis_on_demo(cfg, video, results, output_pth, network.smpl, vis_global=run_global)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--video', type=str,
                        default='examples/demo_video.mp4',
                        help='input video path or YouTube link')

    parser.add_argument('--output_pth', type=str, default='output/demo',
                        help='output folder to write results')

    parser.add_argument('--calib', type=str, default=None,
                        help='camera calibration file path')

    parser.add_argument('--estimate_local_only', action='store_true',
                        help='only estimate motion in camera coordinates')

    parser.add_argument('--visualize', action='store_true',
                        help='visualize the output mesh')

    parser.add_argument('--save_pkl', action='store_true',
                        help='save output as a pkl file')

    parser.add_argument('--run_smplify', action='store_true',
                        help='run Temporal SMPLify for post-processing')

    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file('configs/yamls/demo.yaml')

    logger.info(f'GPU name -> {torch.cuda.get_device_name()}')
    logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}')

    # ========= Load WHAM ========= #
    smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN
    smpl = build_body_model(cfg.DEVICE, smpl_batch_size)
    network = build_network(cfg, smpl)
    network.eval()

    # Output folder
    sequence = '.'.join(args.video.split('/')[-1].split('.')[:-1])
    output_pth = osp.join(args.output_pth, sequence)
    os.makedirs(output_pth, exist_ok=True)

    run(cfg,
        args.video,
        output_pth,
        network,
        args.calib,
        run_global=not args.estimate_local_only,
        run_smplify=args.run_smplify,
        save_pkl=args.save_pkl,
        visualize=args.visualize)

    print()
    logger.info('Done!')
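
# Example invocation (a minimal sketch; the file name `demo.py` and the availability of the
# demo config, pretrained checkpoints, and example video are assumptions based on the
# defaults above):
#
#   python demo.py --video examples/demo_video.mp4 --visualize --save_pkl
#
# Add --run_smplify for Temporal SMPLify post-processing, --calib <path> to supply known
# camera intrinsics, or --estimate_local_only to skip global (SLAM-based) trajectory
# estimation.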