from __future__ import annotations

import os
import os.path as osp
from collections import defaultdict

import cv2
import torch
import numpy as np
import scipy.signal as signal
from progress.bar import Bar
from scipy.ndimage import gaussian_filter1d

from configs import constants as _C
from .backbone.hmr2 import hmr2
from .backbone.utils import process_image
from ...utils.imutils import flip_kp, flip_bbox

# Repository root, resolved relative to this file (four directory levels up).
ROOT_DIR = osp.abspath(f"{__file__}/../../../../")

class FeatureExtractor(object):
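    """Extract per-frame HMR2 image features and initial SMPL estimates for tracked subjects."""
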
    def __init__(self, device, flip_eval=False, max_batch_size=64):
        
        self.device = device
        self.flip_eval = flip_eval
        self.max_batch_size = max_batch_size
        
        ckpt = osp.join(ROOT_DIR, 'checkpoints', 'hmr2a.ckpt')
        self.model = hmr2(ckpt).to(device).eval()
    
    @torch.no_grad()
    def run(self, video, tracking_results, patch_h=256, patch_w=256):
        """Extract features for every tracked detection in a video file or an image-path list."""
        
        if osp.isfile(video):
            cap = cv2.VideoCapture(video)
            is_video = True
            length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width, height = cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        else:   # Image list
            cap = video
            is_video = False
            length = len(video)
            height, width = cv2.imread(video[0]).shape[:2]
        
        frame_id = 0
        bar = Bar('Feature extraction ...', fill='#', max=length)
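        # Iterate over frames; for each frame, crop every tracked subject and encode it with the HMR2 backbone.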
        while True:
            if is_video:
                flag, img = cap.read()
                if not flag:
                    break
            else:
                if frame_id >= len(cap):
                    break
                img = cv2.imread(cap[frame_id])
            
            for _id, val in tracking_results.items():
                # Skip subjects that are not tracked in this frame.
                if frame_id not in val['frame_id']: continue
                
                frame_id2 = np.where(val['frame_id'] == frame_id)[0][0]
                bbox = val['bbox'][frame_id2]
                cx, cy, scale = bbox
                
                norm_img, crop_img = process_image(img[..., ::-1], [cx, cy], scale, patch_h, patch_w)
                norm_img = torch.from_numpy(norm_img).unsqueeze(0).to(self.device)
                feature = self.model(norm_img, encode=True)
                tracking_results[_id]['features'].append(feature.cpu())
                
                if frame_id2 == 0: # First frame of this subject
                    tracking_results = self.predict_init(norm_img, tracking_results, _id, flip_eval=False)
                    
                # Flip-augmented evaluation: mirror the bbox, keypoints, and crop, and store the flipped features.
                if self.flip_eval:
                    flipped_bbox = flip_bbox(bbox, width, height)
                    tracking_results[_id]['flipped_bbox'].append(flipped_bbox)
                    
                    keypoints = val['keypoints'][frame_id2]
                    flipped_keypoints = flip_kp(keypoints, width)
                    tracking_results[_id]['flipped_keypoints'].append(flipped_keypoints)
                    
                    flipped_features = self.model(torch.flip(norm_img, (3, )), encode=True)
                    tracking_results[_id]['flipped_features'].append(flipped_features.cpu())
                    
                    if frame_id2 == 0:
                        tracking_results = self.predict_init(torch.flip(norm_img, (3, )), tracking_results, _id, flip_eval=True)
                    
            bar.next()
            frame_id += 1
        bar.finish()
        
        return self.process(tracking_results)
    
    def predict_init(self, norm_img, tracking_results, _id, flip_eval=False):
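        """Predict initial SMPL parameters (global orientation, body pose, betas) from a subject's first crop."""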
        prefix = 'flipped_' if flip_eval else ''
        
        pred_global_orient, pred_body_pose, pred_betas, _ = self.model(norm_img, encode=False)
        tracking_results[_id][prefix + 'init_global_orient'] = pred_global_orient.cpu()
        tracking_results[_id][prefix + 'init_body_pose'] = pred_body_pose.cpu()
        tracking_results[_id][prefix + 'init_betas'] = pred_betas.cpu()
        return tracking_results
    
    def process(self, tracking_results):
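        """Collate the per-frame lists into stacked tensors / arrays for each tracked subject."""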
        output = defaultdict(dict)
        
        for _id, results in tracking_results.items():
            
            for key, val in results.items():
                if isinstance(val, list):
                    if isinstance(val[0], torch.Tensor):
                        val = torch.cat(val)
                    elif isinstance(val[0], np.ndarray):
                        val = np.array(val)
                output[_id][key] = val
        
        return output
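

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes an example
    # video path and pre-computed tracking results from the upstream detector/tracker:
    # a dict mapping subject id -> dict with 'frame_id', 'bbox' (cx, cy, scale), and an
    # empty 'features' list. Because this module uses relative imports, it must be run
    # in package context (e.g. via `python -m ...`), with the HMR2 checkpoint in place.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    extractor = FeatureExtractor(device, flip_eval=False)

    # Dummy single-subject track over the first 10 frames; real values come from the tracking stage.
    tracking_results = {
        0: {
            'frame_id': np.arange(10),                                  # frames where subject 0 is visible
            'bbox': np.tile(np.array([320.0, 240.0, 1.0]), (10, 1)),   # per-frame (cx, cy, scale), dummy values
            'features': [],                                             # filled by FeatureExtractor.run
        }
    }
    results = extractor.run('examples/demo.mp4', tracking_results)
    print({k: tuple(v.shape) if hasattr(v, 'shape') else v for k, v in results[0].items()})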