File size: 9,334 Bytes
c87d1bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os, sys

import torch
import numpy as np
from lib.utils import transforms

from smplx import SMPL as _SMPL
from smplx.utils import SMPLOutput as ModelOutput
from smplx.lbs import vertices2joints

from configs import constants as _C

class SMPL(_SMPL):
    """Extension of the official SMPL implementation to support more joints.

    Registers WHAM / H36M / feet joint regressors as buffers and provides
    forward passes that take 6D rotation inputs and optionally project the
    regressed joints into the image plane.
    """

    def __init__(self, *args, **kwargs):
        # The smplx constructor prints to stdout; silence it while it runs.
        # try/finally guarantees stdout is restored and the devnull handle is
        # closed even if construction raises (previously both leaked).
        devnull = open(os.devnull, 'w')
        sys.stdout = devnull
        try:
            super(SMPL, self).__init__(*args, **kwargs)
        finally:
            sys.stdout = sys.__stdout__
            devnull.close()

        # Sparse (J x V) regressors mapping SMPL vertices to custom joint sets.
        J_regressor_wham = np.load(_C.BMODEL.JOINTS_REGRESSOR_WHAM)
        J_regressor_eval = np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
        self.register_buffer('J_regressor_wham', torch.tensor(
            J_regressor_wham, dtype=torch.float32))
        self.register_buffer('J_regressor_eval', torch.tensor(
            J_regressor_eval, dtype=torch.float32))
        self.register_buffer('J_regressor_feet', torch.from_numpy(
            np.load(_C.BMODEL.JOINTS_REGRESSOR_FEET)
        ).float())

    def get_local_pose_from_reduced_global_pose(self, reduced_pose):
        """Scatter rotations for _C.BMODEL.MAIN_JOINTS into a full 24-joint
        pose; every joint not in MAIN_JOINTS gets the identity rotation.

        Args:
            reduced_pose: (B, len(MAIN_JOINTS), 3, 3) rotation matrices.

        Returns:
            (B, 24, 3, 3) full per-joint rotation matrices.
        """
        full_pose = torch.eye(
            3, device=reduced_pose.device
        )[(None, ) * 2].repeat(reduced_pose.shape[0], 24, 1, 1)
        full_pose[:, _C.BMODEL.MAIN_JOINTS] = reduced_pose
        return full_pose

    def forward(self,
                pred_rot6d,
                betas,
                cam=None,
                cam_intrinsics=None,
                bbox=None,
                res=None,
                return_full_pose=False,
                **kwargs):
        """Run SMPL from 6D rotations; optionally project joints to 2D.

        Args:
            pred_rot6d: (B, F, ...) 6D rotations, reshaped to 24 joints.
            betas: shape coefficients, reshaped to (-1, 10).
            cam: (B, F, 3) weak-perspective camera (s, tx, ty). When given,
                `cam_intrinsics`, `bbox` and `res` are all required and both
                weak- and full-perspective 2D joints are attached.
            cam_intrinsics: (B, F, 3, 3) camera intrinsics.
            bbox: (B, F, 3) bbox center (x, y) and scale.
            res: (B, 2) image resolution — assumed (width, height) order
                from the argument positions below; TODO confirm with caller.
            return_full_pose: forwarded to the underlying smplx model.

        Returns:
            ModelOutput with re-centered vertices/joints and, when `cam` is
            given, extra `weak_joints2d`, `full_joints2d`, `full_cam` fields.
        """
        rotmat = transforms.rotation_6d_to_matrix(
            pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
        ).reshape(-1, 24, 3, 3)

        output = self.get_output(body_pose=rotmat[:, 1:],
                                 global_orient=rotmat[:, :1],
                                 betas=betas.view(-1, 10),
                                 pose2rot=False,
                                 return_full_pose=return_full_pose)

        if cam is not None:
            joints3d = output.joints.reshape(*cam.shape[:2], -1, 3)

            # Weak perspective projection (for InstaVariety)
            weak_cam = convert_weak_perspective_to_perspective(cam)

            weak_joints2d = weak_perspective_projection(
                joints3d,
                rotation=torch.eye(3, device=cam.device).unsqueeze(0).unsqueeze(0).expand(*cam.shape[:2], -1, -1),
                translation=weak_cam,
                focal_length=5000.,
                camera_center=torch.zeros(*cam.shape[:2], 2, device=cam.device)
            )
            output.weak_joints2d = weak_joints2d

            # Full perspective projection
            full_cam = convert_pare_to_full_img_cam(
                cam,
                bbox[:, :, 2] * 200.,  # bbox scale (b/200) -> pixel height
                bbox[:, :, :2],
                res[:, 0].unsqueeze(-1),
                res[:, 1].unsqueeze(-1),
                focal_length=cam_intrinsics[:, :, 0, 0]
            )

            full_joints2d = full_perspective_projection(
                joints3d,
                translation=full_cam,
                cam_intrinsics=cam_intrinsics,
            )
            output.full_joints2d = full_joints2d
            output.full_cam = full_cam.reshape(-1, 3)

        return output

    def forward_nd(self,
                   pred_rot6d,
                   root,
                   betas,
                   return_full_pose=False):
        """Run SMPL from 6D body rotations with an explicit root orientation.

        Args:
            pred_rot6d: (B, F, ...) 6D rotations reshaped to 24 joints; the
                root entry is ignored and replaced by `root`.
            root: global orientation, reshaped to (-1, 1, 3, 3).
            betas: shape coefficients, reshaped to (-1, 10).
            return_full_pose: forwarded to the underlying smplx model.
        """
        rotmat = transforms.rotation_6d_to_matrix(
            pred_rot6d.reshape(*pred_rot6d.shape[:2], -1, 6)
        ).reshape(-1, 24, 3, 3)

        output = self.get_output(body_pose=rotmat[:, 1:],
                                 global_orient=root.reshape(-1, 1, 3, 3),
                                 betas=betas.view(-1, 10),
                                 pose2rot=False,
                                 return_full_pose=return_full_pose)

        return output

    def get_output(self, *args, **kwargs):
        """Run the underlying SMPL model and re-center its outputs.

        Vertices, joints and feet are shifted so the midpoint of WHAM joints
        11 and 12 (presumably the hips — confirm against the WHAM joint
        order) sits at the origin, plus `transl` when provided.
        """
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)
        feet = vertices2joints(self.J_regressor_feet, smpl_output.vertices)

        offset = joints[..., [11, 12], :].mean(-2)
        if 'transl' in kwargs:
            offset = offset - kwargs['transl']
        vertices = smpl_output.vertices - offset.unsqueeze(-2)
        joints = joints - offset.unsqueeze(-2)
        feet = feet - offset.unsqueeze(-2)

        output = ModelOutput(vertices=vertices,
                             global_orient=smpl_output.global_orient,
                             body_pose=smpl_output.body_pose,
                             joints=joints,
                             betas=smpl_output.betas,
                             full_pose=smpl_output.full_pose)
        output.feet = feet
        output.offset = offset
        return output

    def get_offset(self, *args, **kwargs):
        """Return the root offset (midpoint of WHAM joints 11 and 12) for
        the given SMPL parameters, without building a full output."""
        kwargs['get_skin'] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        joints = vertices2joints(self.J_regressor_wham, smpl_output.vertices)

        offset = joints[..., [11, 12], :].mean(-2)
        return offset

    def get_faces(self):
        """Return the SMPL triangle faces as a numpy array."""
        return np.array(self.faces)
    

def convert_weak_perspective_to_perspective(
        weak_perspective_camera,
        focal_length=5000.,
        img_res=224,
):
    """Convert a weak-perspective camera (s, tx, ty) to a perspective
    translation (tx, ty, tz).

    Depth is recovered from the scale via tz = 2f / (img_res * s); the small
    epsilon keeps the division finite when s is zero.
    """
    scale = weak_perspective_camera[..., 0]
    tx = weak_perspective_camera[..., 1]
    ty = weak_perspective_camera[..., 2]
    tz = 2 * focal_length / (img_res * scale + 1e-9)
    return torch.stack([tx, ty, tz], dim=-1)


def weak_perspective_projection(
        points,
        rotation,
        translation,
        focal_length,
        camera_center,
        img_res=224,
        normalize_joints2d=True,
):
    """Compute the perspective projection of a set of points.

    Input:
        points (b, f, N, 3): 3D points
        rotation (b, f, 3, 3): Camera rotation
        translation (b, f, 3): Camera translation
        focal_length (b, f,) or scalar: Focal length
        camera_center (b, f, 2): Camera center
        img_res: crop resolution used for normalization
        normalize_joints2d: if True, scale pixels by 2 / img_res

    Returns (b, f, N, 2) projected points.
    """
    batch, frames = points.shape[:2]

    # Assemble the intrinsic matrix K for every (batch, frame).
    K = torch.zeros([batch, frames, 3, 3], device=points.device)
    K[:, :, 0, 0] = focal_length
    K[:, :, 1, 1] = focal_length
    K[:, :, 2, 2] = 1.
    K[:, :, :-1, -1] = camera_center

    # Rigid transform into camera coordinates.
    cam_points = torch.einsum('bfij,bfkj->bfki', rotation, points)
    cam_points = cam_points + translation.unsqueeze(-2)

    # Perspective divide by depth.
    homogeneous = cam_points / cam_points[..., -1].unsqueeze(-1)

    # Apply camera intrinsics.
    projected = torch.einsum('bfij,bfkj->bfki', K, homogeneous)

    if normalize_joints2d:
        projected = projected / (img_res / 2.)

    return projected[..., :-1]

    
def full_perspective_projection(
        points,
        cam_intrinsics,
        rotation=None,
        translation=None,
):
    """Project 3D points into pixel coordinates with full intrinsics.

    Args:
        points (..., N, 3): 3D points.
        cam_intrinsics (..., 3, 3): intrinsic matrix K.
        rotation (..., 3, 3), optional: camera rotation.
        translation (..., 3), optional: camera translation.

    Returns (..., N, 2) pixel coordinates.
    """
    if rotation is not None:
        points = (rotation @ points.transpose(-1, -2)).transpose(-1, -2)
    if translation is not None:
        points = points + translation.unsqueeze(-2)

    # Perspective divide by depth, then apply K.
    depth = points[..., -1].unsqueeze(-1)
    normalized = points / depth
    pixels = (cam_intrinsics @ normalized.transpose(-1, -2)).transpose(-1, -2)
    return pixels[..., :-1]


def convert_pare_to_full_img_cam(
        pare_cam,
        bbox_height,
        bbox_center,
        img_w,
        img_h,
        focal_length,
        crop_res=224
):
    """Lift a PARE-style crop camera (s, tx, ty) to a full-image translation.

    NOTE(review): unlike cam_crop2full below, the division by s has no
    epsilon guard — confirm callers never produce s == 0.
    """
    s = pare_cam[..., 0]
    tx = pare_cam[..., 1]
    ty = pare_cam[..., 2]

    # Depth from the weak-perspective scale: tz = 2f / (s * bbox_height).
    r = bbox_height / crop_res
    tz = 2 * focal_length / (r * crop_res * s)

    # Offset of the bbox center from the image center, in camera units.
    dx = 2 * (bbox_center[..., 0] - (img_w / 2.)) / (s * bbox_height)
    dy = 2 * (bbox_center[..., 1] - (img_h / 2.)) / (s * bbox_height)

    return torch.stack([tx + dx, ty + dy, tz], dim=-1)


def cam_crop2full(crop_cam, center, scale, full_img_shape, focal_length):
    """Convert camera parameters from the crop camera to the full camera.

    :param crop_cam: shape=(N, 3) weak perspective camera in cropped img coordinates (s, tx, ty)
    :param center: shape=(N, 2) bbox coordinates (c_x, c_y)
    :param scale: shape=(N) square bbox resolution  (b / 200)
    :param full_img_shape: shape=(N, 2) original image height and width
    :param focal_length: shape=(N,)
    :return: shape=(N, 3) camera translation in full-image coordinates
    """
    img_h = full_img_shape[:, 0]
    img_w = full_img_shape[:, 1]
    box_size = scale * 200
    # Epsilon keeps the divide finite when the predicted scale is zero.
    denom = box_size * crop_cam[:, 0] + 1e-9
    tz = 2 * focal_length / denom
    tx = 2 * (center[:, 0] - img_w / 2.) / denom + crop_cam[:, 1]
    ty = 2 * (center[:, 1] - img_h / 2.) / denom + crop_cam[:, 2]
    return torch.stack([tx, ty, tz], dim=-1)