|
import numpy as np |
|
import torch |
|
|
|
|
|
def batch_mm(matrix, matrix_batch):
    """
    Multiply a single (possibly sparse) matrix with every matrix of a batch.

    Rather than issuing b separate products, the batch is folded into one wide
    (n, b * k) operand so that a single ``matrix.mm`` call does all the work.
    See https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242

    :param matrix: Sparse or dense matrix, size (m, n).
    :param matrix_batch: Batched dense matrices, size (b, n, k).
    :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
    """
    n_rows, n_inner = matrix.shape[0], matrix.shape[1]
    n_batch = matrix_batch.shape[0]

    # (b, n, k) -> (n, b, k) -> (n, b * k)
    wide = matrix_batch.transpose(0, 1).reshape(n_inner, -1)

    # (m, n) @ (n, b * k) -> (m, b * k)
    product = matrix.mm(wide)

    # (m, b * k) -> (m, b, k) -> (b, m, k)
    return product.reshape(n_rows, n_batch, -1).transpose(1, 0)
|
|
|
|
|
def aa2quat(rots, form='wxyz', unified_orient=True):
    """
    Convert angle-axis rotations to unit quaternions.

    @param rots: angle-axis rotations, (*, 3); the vector direction is the rotation
        axis and its length is the rotation angle in radians
    @param form: quaternion component order, either 'wxyz' or 'xyzw'
    @param unified_orient: if True, negate quaternions whose scalar part w is negative
        so that w >= 0 (q and -q encode the same rotation, since quaternions are a
        double cover of SO3)
    :return: quaternions of shape (*, 4) in the requested component order
    :raises ValueError: if ``form`` is neither 'wxyz' nor 'xyzw'
    """
    # Position of the scalar part and of the vector part for each layout.
    if form == 'wxyz':
        w_idx, xyz_idx = 0, slice(1, 4)
    elif form == 'xyzw':
        w_idx, xyz_idx = 3, slice(0, 3)
    else:
        raise ValueError("form must be 'wxyz' or 'xyzw', got %r" % (form,))

    angles = rots.norm(dim=-1, keepdim=True)

    # Avoid dividing by ~0 for (near-)zero rotations; sin(angle / 2) is ~0 there anyway.
    norm = angles.clone()
    norm[norm < 1e-8] = 1
    axis = rots / norm

    half = angles * 0.5
    quats = torch.empty(rots.shape[:-1] + (4,), device=rots.device, dtype=rots.dtype)
    quats[..., w_idx] = torch.cos(half.squeeze(-1))
    quats[..., xyz_idx] = torch.sin(half) * axis

    if unified_orient:
        # Test the actual scalar component: it is at index 3 in 'xyzw' layout, not 0.
        flip = quats[..., w_idx] < 0
        quats[flip, :] *= -1

    return quats
|
|
|
|
|
def quat2aa(quats):
    """
    Convert (w, x, y, z) quaternions to angle-axis vectors.

    :param quats: quaternions in wxyz order, shape (..., 4)
    :return: angle-axis vectors, shape (..., 3); the direction is the rotation axis
        and the length is the rotation angle in radians (2 * atan2(|xyz|, w))
    """
    scalar = quats[..., 0]
    vec = quats[..., 1:]
    vec_len = vec.norm(dim=-1)

    # Guard the division for (near-)identity rotations whose vector part is ~0.
    safe_len = vec_len.clone()
    safe_len[safe_len < 1e-7] = 1
    unit_axis = vec / safe_len.unsqueeze(-1)

    full_angle = 2 * torch.atan2(vec_len, scalar)
    return unit_axis * full_angle.unsqueeze(-1)
|
|
|
|
|
def quat2mat(quats: torch.Tensor):
    """
    Convert (w, x, y, z) quaternions to 3x3 rotation matrices.

    The standard closed form is used, which assumes unit-norm input.

    :param quats: quaternions of shape (..., 4)
    :return: rotation matrices of shape (..., 3, 3), same dtype/device as ``quats``
    """
    w, x, y, z = (quats[..., i] for i in range(4))

    # Doubled components, so every product below already carries the factor 2.
    dx = x + x
    dy = y + y
    dz = z + z

    xx, yy, zz = x * dx, y * dy, z * dz
    xy, yz, xz = x * dy, y * dz, x * dz
    wx, wy, wz = w * dx, w * dy, w * dz

    entries = {
        (0, 0): 1.0 - (yy + zz),
        (0, 1): xy - wz,
        (0, 2): xz + wy,
        (1, 0): xy + wz,
        (1, 1): 1.0 - (xx + zz),
        (1, 2): yz - wx,
        (2, 0): xz - wy,
        (2, 1): yz + wx,
        (2, 2): 1.0 - (xx + yy),
    }

    out = quats.new_empty(quats.shape[:-1] + (3, 3))
    for (row, col), value in entries.items():
        out[..., row, col] = value
    return out
|
|
|
|
|
def quat2euler(q, order='xyz', degrees=True):
    """
    Convert (w, x, y, z) quaternions to Euler angles. This is used for bvh output.

    For order 'xyz' the angles (a, b, c) satisfy R = Rx(a) @ Ry(b) @ Rz(c), which is
    the same composition ``euler2mat`` builds for order 'xyz'.

    :param q: quaternions in wxyz order, shape (..., 4)
    :param order: Euler axis order; only 'xyz' is supported
    :param degrees: return degrees if True, radians otherwise
    :return: Euler angles of shape (..., 3)
    :raises NotImplementedError: for any order other than 'xyz'
    """
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    euler = q.new_empty(w.shape + (3,))

    if order != 'xyz':
        raise NotImplementedError('Cannot convert to ordering %s' % order)

    euler[..., 2] = torch.atan2(2 * (w * z - x * y), w * w + x * x - y * y - z * z)
    # Clip guards asin against values drifting slightly outside [-1, 1].
    euler[..., 1] = torch.asin((2 * (x * z + w * y)).clip(-1, 1))
    euler[..., 0] = torch.atan2(2 * (w * x - y * z), w * w - x * x - y * y + z * z)

    return euler * 180 / np.pi if degrees else euler
|
|
|
|
|
def euler2mat(rots, order='xyz', degrees=True):
    """
    Convert Euler angles to rotation matrices.

    The result is R = R_{order[0]} @ R_{order[1]} @ R_{order[2]}, where
    R_a(theta) is the rotation by theta about axis a.

    :param rots: Euler angles, shape (..., 3); rots[..., i] is the angle about axis order[i]
    :param order: three characters, each one of 'x', 'y', 'z' (default 'xyz')
    :param degrees: True (default) if ``rots`` is in degrees, False if in radians
    :return: rotation matrices, shape (..., 3, 3)
    :raises ValueError: if ``order`` is not three characters drawn from 'xyz'
    """
    if len(order) != 3 or any(ax not in 'xyz' for ax in order):
        raise ValueError("order must be three characters from 'xyz', got %r" % (order,))

    unit_axes = {'x': (1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, 1)}
    rad = rots / 180 * np.pi if degrees else rots

    mats = []
    for i, ax in enumerate(order):
        axis = torch.tensor(unit_axes[ax], device=rots.device)
        # Angle-axis vector for this elementary rotation: unit axis scaled by the angle.
        mats.append(aa2mat(axis * rad[..., i].unsqueeze(-1)))
    return mats[0] @ (mats[1] @ mats[2])
|
|
|
|
|
def aa2mat(rots):
    """
    Convert angle-axis rotations to 3x3 rotation matrices.

    The conversion goes through an intermediate wxyz quaternion
    (``aa2quat`` with its default arguments) followed by ``quat2mat``.

    :param rots: angle-axis rotations, shape (..., 3)
    :return: rotation matrices, shape (..., 3, 3)
    """
    return quat2mat(aa2quat(rots))
|
|
|
|
|
def mat2quat(R) -> torch.Tensor:
    """
    Convert rotation matrices to (w, x, y, z) unit quaternions.

    Adapted from https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py

    This uses Shepperd's method for numerical stability. The diagonal decides which
    component is recovered through a square root. In every branch the radicand is
    at least 1. The remaining three components are then obtained by dividing the
    off-diagonal sums/differences by that leading component.

    :param R: rotation matrices, shape (..., 3, 3)
    :return: quaternions in wxyz order, shape (..., 4); the overall sign is not normalized
    """
    d0, d1, d2 = R[..., 0, 0], R[..., 1, 1], R[..., 2, 2]

    # 4 * q_i^2 for component index i (0=w, 1=x, 2=y, 3=z).
    four_sq = (
        1 + d0 + d1 + d2,
        1 + d0 - d1 - d2,
        1 - d0 + d1 - d2,
        1 - d0 - d1 + d2,
    )

    # 4 * q_i * q_j, keyed by sorted component index pairs.
    four_prod = {
        (0, 1): R[..., 2, 1] - R[..., 1, 2],
        (0, 2): R[..., 0, 2] - R[..., 2, 0],
        (0, 3): R[..., 1, 0] - R[..., 0, 1],
        (1, 2): R[..., 0, 1] + R[..., 1, 0],
        (1, 3): R[..., 2, 0] + R[..., 0, 2],
        (2, 3): R[..., 1, 2] + R[..., 2, 1],
    }

    lower = d2 < 0
    upper = d2 >= 0
    # (mask, index of the component taken via sqrt); masks partition the batch.
    branches = (
        (lower & (d0 > d1), 1),
        (lower & (d0 <= d1), 2),
        (upper & (d0 < -d1), 3),
        (upper & (d0 >= -d1), 0),
    )

    comps = [torch.empty_like(four_sq[1]) for _ in range(4)]
    for mask, lead in branches:
        lead_val = torch.sqrt(four_sq[lead][mask])
        comps[lead][mask] = lead_val
        for other in range(4):
            if other == lead:
                continue
            key = (min(lead, other), max(lead, other))
            comps[other][mask] = four_prod[key][mask] / lead_val

    # Every component above carries a factor of 2.
    return torch.stack(comps, dim=-1) / 2
|
|
|
|
|
def quat2repr6d(quat):
    """
    Convert (w, x, y, z) quaternions to the 6D rotation representation.

    The 6D vector is the first two rows of the rotation matrix, concatenated.

    :param quat: quaternions, shape (..., 4)
    :return: 6D representation, shape (..., 6)
    """
    top_rows = quat2mat(quat)[..., :2, :]
    return top_rows.reshape(*top_rows.shape[:-2], 6)
|
|
|
|
|
def repr6d2mat(repr):
    """
    Convert the 6D rotation representation to rotation matrices (Gram-Schmidt).

    :param repr: shape (..., 6); the first 3 values give the first matrix row, the last
        3 an (unnormalized, not necessarily orthogonal) second row
    :return: rotation matrices, shape (..., 3, 3), whose rows are x, y, z
    """
    x = repr[..., :3]
    y = repr[..., 3:]
    x = x / x.norm(dim=-1, keepdim=True)
    # dim=-1 must be explicit: without it torch.cross uses the FIRST dimension of
    # size 3, which is wrong whenever a batch dimension also happens to equal 3.
    z = torch.cross(x, y, dim=-1)
    z = z / z.norm(dim=-1, keepdim=True)
    y = torch.cross(z, x, dim=-1)
    return torch.stack((x, y, z), dim=-2)
|
|
|
|
|
def repr6d2quat(repr) -> torch.Tensor:
    """
    Convert the 6D rotation representation to (w, x, y, z) quaternions.

    The 6D vector is first orthonormalized into a rotation matrix (Gram-Schmidt),
    which is then converted with ``mat2quat``.

    :param repr: shape (..., 6); first matrix row followed by an approximate second row
    :return: quaternions in wxyz order, shape (..., 4)
    """
    x = repr[..., :3]
    y = repr[..., 3:]
    x = x / x.norm(dim=-1, keepdim=True)
    # dim=-1 must be explicit: without it torch.cross uses the FIRST dimension of
    # size 3, which is wrong whenever a batch dimension also happens to equal 3.
    z = torch.cross(x, y, dim=-1)
    z = z / z.norm(dim=-1, keepdim=True)
    y = torch.cross(z, x, dim=-1)
    mat = torch.stack((x, y, z), dim=-2)
    return mat2quat(mat)
|
|
|
|
|
def inv_affine(mat):
    """
    Calculate the inverse of any (invertible) affine transformation.

    :param mat: affine transforms stored as their top three rows [A | t], shape (..., 3, 4)
        with any number of leading batch dimensions (including none)
    :return: the top three rows of the inverse transforms, shape (..., 3, 4)
    """
    # Append the homogeneous row [0, 0, 0, 1] so each transform becomes a square 4x4 matrix.
    # Match mat's device/dtype so CUDA / float64 inputs work.
    bottom = torch.zeros(mat.shape[:-2] + (1, 4), device=mat.device, dtype=mat.dtype)
    bottom[..., 3] = 1
    square = torch.cat((mat, bottom), dim=-2)
    return torch.inverse(square)[..., :3, :]
|
|
|
|
|
def inv_rigid_affine(mat):
    """
    Calculate the inverse of a rigid affine transformation.

    For [R | t] with orthonormal R the inverse is [R^T | -R^T t], which avoids a
    general matrix inversion.

    :param mat: rigid transforms [R | t], shape (..., 3, 4); not modified
    :return: inverse transforms, shape (..., 3, 4)
    """
    rot_t = mat[..., :3].transpose(-2, -1)
    trans = mat[..., 3].unsqueeze(-1)

    out = mat.clone()
    out[..., :3] = rot_t
    out[..., 3] = -torch.matmul(rot_t, trans).squeeze(-1)
    return out
|
|
|
|
|
def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):
    """
    Sample random poses as per-bone angle-axis rotations.

    Each sampled bone gets a random unit axis scaled by a random angle (radians).

    :param batch_size: number of poses to sample
    :param device: torch device for the returned tensor
    :param uniform: if True, angles are drawn from U[0, pi); otherwise from
        N(0, (pi / 6 * factor)^2), clamped to [-pi, pi]
    :param factor: scale of the normal angle distribution (ignored when ``uniform``)
    :param root_rot: if False, the first 3 entries (bone 0, presumably the root) are
        zeroed; if True and ``ee`` is given, bone 0 is sampled as well
    :param n_bone: total number of bones, default 24
    :param ee: optional list of bone indices to randomize; all other bones stay zero.
        The caller's list is not modified.
    :return: poses of shape (batch_size, n_bone * 3)
    """
    if n_bone is None:
        n_bone = 24

    if ee is not None:
        # Copy so the caller's list is not mutated (and repeat calls don't keep appending 0).
        ee = list(ee)
        if root_rot:
            ee.append(0)
        n_sampled = len(ee)
    else:
        n_sampled = n_bone

    axis = torch.randn((batch_size, n_sampled, 3), device=device)
    axis /= axis.norm(dim=-1, keepdim=True)

    if uniform:
        angle = torch.rand((batch_size, n_sampled, 1), device=device) * np.pi
    else:
        angle = torch.randn((batch_size, n_sampled, 1), device=device) * np.pi / 6 * factor
        # clamp is out-of-place; the result must be assigned to take effect.
        angle = angle.clamp(-np.pi, np.pi)

    poses = axis * angle

    if ee is not None:
        res = torch.zeros((batch_size, n_bone, 3), device=device)
        for i, bone in enumerate(ee):
            res[:, bone] = poses[:, i]
        poses = res

    poses = poses.reshape(batch_size, -1)
    if not root_rot:
        poses[..., :3] = 0
    return poses
|
|
|
|
|
def slerp(l, r, t, unit=True):
    """
    Spherical linear interpolation between vectors.

    Where l and r are (anti)parallel (sin of the angle between them below 1e-8) the
    spherical formula is ill-conditioned, so plain linear interpolation is used there.

    :param l: start vectors, shape = (*, n)
    :param r: end vectors, shape = (*, n)
    :param t: interpolation weights, shape = (*); t = 0 yields l and t = 1 yields r
    :param unit: whether l and r are already unit vectors; if False they are normalized first
    :return: interpolated vectors, shape = (*, n)
    """
    eps = 1e-8
    if unit:
        lhs, rhs = l, r
    else:
        lhs = l / torch.norm(l, dim=-1, keepdim=True)
        rhs = r / torch.norm(r, dim=-1, keepdim=True)

    omega = torch.acos((lhs * rhs).sum(dim=-1).clamp(-1, 1))
    sin_omega = torch.sin(omega)
    out = torch.empty_like(lhs)

    near = sin_omega < eps
    w = t[near].unsqueeze(-1)
    out[near] = (1 - w) * lhs[near] + w * rhs[near]

    far = ~near
    t_far = t[far]
    sin_far = sin_omega[far]
    omega_far = omega[far]
    coef_l = torch.sin((1 - t_far) * omega_far) / sin_far
    coef_r = torch.sin(t_far * omega_far) / sin_far
    out[far] = coef_l.unsqueeze(-1) * lhs[far] + coef_r.unsqueeze(-1) * rhs[far]
    return out
|
|
|
|
|
def slerp_quat(l, r, t):
    """
    slerp for unit quaternions, taking the shorter arc.

    Where l and r have a negative dot product, -l (which encodes the same rotation)
    is interpolated towards r instead of l.

    :param l: (*, 4) unit quaternion
    :param r: (*, 4) unit quaternion
    :param t: (*) scalar between 0 and 1; must be expandable to l.shape[:-1]
    :return: (*, 4) interpolated quaternions
    """
    t = t.expand(l.shape[:-1])
    same_side = (l * r).sum(dim=-1) >= 0
    opposite = ~same_side

    out = torch.empty_like(l)
    out[same_side] = slerp(l[same_side], r[same_side], t[same_side])
    out[opposite] = slerp(-l[opposite], r[opposite], t[opposite])
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interpolate_6d(input, size):
    """
    Resample a sequence of 6D rotations along the time axis using quaternion slerp.

    Output frame i samples input position i / size * (length - 1), so the last output
    frame lies slightly before the final input frame.

    :param input: (batch_size, n_channels, length); n_channels must be divisible by 6
        (consecutive groups of 6 channels form one 6D rotation)
    :param size: required output size for temporal axis
    :return: (batch_size, n_channels, size)
    """
    batch = input.shape[0]
    length = input.shape[-1]

    # (batch, n_channels, length) -> (batch, n_rot, 6, length) -> (batch, n_rot, length, 6)
    rots = input.reshape((batch, -1, 6, length)).permute(0, 1, 3, 2)
    input_q = repr6d2quat(rots)

    pos = torch.arange(size, device=input_q.device, dtype=torch.float) / size * (length - 1)
    idx_l = torch.floor(pos)
    t = (pos - idx_l).reshape((1, 1, -1))
    idx_l = idx_l.long()
    # Clamp keeps the right neighbour in range for single-frame inputs (length == 1)
    # and if float rounding ever lands pos exactly on the last frame.
    idx_r = (idx_l + 1).clamp(max=length - 1)

    res_q = slerp_quat(input_q[..., idx_l, :], input_q[..., idx_r, :], t)

    # (batch, n_rot, size, 6) -> (batch, n_rot, 6, size) -> (batch, n_channels, size)
    res = quat2repr6d(res_q).permute(0, 1, 3, 2)
    return res.reshape((batch, -1, size))
|
|