Spaces:
Runtime error
Runtime error
File size: 9,144 Bytes
753fd9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import torch
import numpy as np
'''
def keyp_rep_error_l1(smpl_keyp_2d, keyp_hourglass, keyp_hourglass_scores, thr_kp=0.3):
# step 1: make sure that the hg prediction and barc are close
with torch.no_grad():
kp_weights = keyp_hourglass_scores
kp_weights[keyp_hourglass_scores<thr_kp] = 0
loss_keyp_rep = torch.mean((torch.abs((smpl_keyp_2d - keyp_hourglass)/512)).sum(dim=2)*kp_weights[:, :, 0])
return loss_keyp_rep
def keyp_rep_error(smpl_keyp_2d, keyp_hourglass, keyp_hourglass_scores, thr_kp=0.3):
# step 1: make sure that the hg prediction and barc are close
with torch.no_grad():
kp_weights = keyp_hourglass_scores
kp_weights[keyp_hourglass_scores<thr_kp] = 0
# losses['kp_reproj']['value'] = torch.mean((((smpl_keyp_2d - keyp_reproj_init)/512)**2).sum(dim=2)*kp_weights[:, :, 0])
loss_keyp_rep = torch.mean((((smpl_keyp_2d - keyp_hourglass)/512)**2).sum(dim=2)*kp_weights[:, :, 0])
return loss_keyp_rep
'''
def leg_sideway_error(optimed_pose_with_glob):
assert optimed_pose_with_glob.shape[1] == 35
leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
# leg_indices_right = np.asarray([8, 9, 10, 18, 19, 20]) # front, back
# leg_indices_left = np.asarray([12, 13, 14, 22, 23, 24]) # front, back
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
vec[2] = -1
x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
loss_pose_legs_side = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
return loss_pose_legs_side
def leg_torsion_error(optimed_pose_with_glob):
leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
vec_x[0] = 1 # in x direction
x_x_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec_x
x_x_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec_x
loss_pose_legs_torsion = (x_x_legs_left[:, 1]**2).mean() + (x_x_legs_right[:, 1]**2).mean()
return loss_pose_legs_torsion
def frontleg_walkingdir_error(optimed_pose_with_glob):
# this prior should only be used for standing poses!
leg_indices_right = np.asarray([7, 8, 9, 10]) # front, back
leg_indices_left = np.asarray([11, 12, 13, 14]) # front, back
relevant_back_indices = np.asarray([1, 2, 3, 4, 5, 6]) # np.asarray([6]) # back joint in the front
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
x0_rotmat_back = x0_rotmat[:, relevant_back_indices, :, :]
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
vec[2] = -1 # vector down
x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
x0_back = x0_rotmat_back.reshape((-1, 3, 3))@vec
loss_pose_legs_side = (x0_legs_left[:, 0]**2).mean() + (x0_legs_right[:, 0]**2).mean() + (x0_back[:, 0]**2).mean() # penalize movement to front
return loss_pose_legs_side
def tail_sideway_error(optimed_pose_with_glob):
tail_indices = np.asarray([25, 26, 27, 28, 29, 30, 31])
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
'''vec[2] = -1
x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
loss_pose_tail_side = (x0_tail[:, 1]**2).mean()'''
vec[0] = -1
x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
loss_pose_tail_side = (x0_tail[:, 1]**2).mean()
return loss_pose_tail_side
def tail_torsion_error(optimed_pose_with_glob):
tail_indices = np.asarray([25, 26, 27, 28, 29, 30, 31])
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
'''vec_x[0] = 1 # in x direction
x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean()'''
vec_x[2] = 1 # in y direction
x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean()
return loss_pose_tail_torsion
def spine_sideway_error(optimed_pose_with_glob):
tail_indices = np.asarray([1, 2, 3, 4, 5, 6]) # was wrong
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
vec[0] = -1
x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
loss_pose_tail_side = (x0_tail[:, 1]**2).mean()
return loss_pose_tail_side
def spine_torsion_error(optimed_pose_with_glob):
tail_indices = np.asarray([1, 2, 3, 4, 5, 6])
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
vec_x[2] = 1 # vec_x[0] = 1 # in z direction
x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean() # (x_x_tail[:, 1]**2).mean()
return loss_pose_tail_torsion
def fit_plane(points_npx3):
# remarks:
# visualization of the plane: debug_code/curve_fitting_v2.py
# theory: https://www.ltu.se/cms_fs/1.51590!/svd-fitting.pdf
# remark: torch.svd is depreciated
# new plane equation:
# a(x−x0)+b(y−y0)+c(z−z0)=0
# ax+by+cz=d with d=ax0+by0+cz0
# z = (d-ax-by)/c
# here:
# a, b, c describe the plane normal
# d can be calculated (from a, b, c, x0, y0, z0)
# (x0, y0, z0) are the coordinates of a point on the
# plane, for example points_centroid
# (x, y, z) are the coordinates of a query point on the plane
#
# points_npx3: (n_points, 3)
# REMARK: this loss is not yet for batches!
# import pdb; pdb.set_trace()
# print('this loss is not yet for batches!')
assert (points_npx3.ndim == 2)
assert (points_npx3.shape[1] == 3)
points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
points_centroid = torch.mean(points, dim=1)
input_svd = points - points_centroid[:, None]
U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
plane_normal = U_svd[:, 2]
plane_squaredsumofdists = sigma_svd[2]
error = plane_squaredsumofdists
return points_centroid, plane_normal, error
def paws_to_groundplane_error(vertices, return_details=False):
# list of feet vertices (some of them)
# remark: we did annotate left indices and find the right insices using sym_ids_dict
# REMARK: this loss is not yet for batches!
# import pdb; pdb.set_trace()
# print('this loss is not yet for batches!')
list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
assert vertices.shape[0] == 3889
assert vertices.shape[1] == 3
all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
verts_paws = vertices[all_paw_vert_idxs, :]
plane_centroid, plane_normal, error = fit_plane(verts_paws)
if return_details:
return plane_centroid, plane_normal, error
else:
return error
def groundcontact_error(vertices, gclabels, return_details=False):
# import pdb; pdb.set_trace()
# REMARK: this loss is not yet for batches!
import pdb; pdb.set_trace()
print('this loss is not yet for batches!')
assert vertices.shape[0] == 3889
assert vertices.shape[1] == 3
verts_gc = vertices[gclabels, :]
plane_centroid, plane_normal, error = fit_plane(verts_gc)
if return_details:
return plane_centroid, plane_normal, error
else:
return error
|