import torch


def qinv(q: torch.Tensor) -> torch.Tensor:
    """
    Return the conjugate of quaternion(s) q.

    Expects a tensor of shape (*, 4) with the scalar (real) part first,
    i.e. (w, x, y, z). The scalar part is kept and the vector part is
    negated. For unit quaternions the conjugate equals the inverse; for
    non-unit quaternions the true inverse would additionally be divided by
    the squared norm, which this function does not do.
    Returns a tensor of the same shape and dtype as q.
    """
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    # Keep w, negate (x, y, z) -- no mask allocation or in-place writes.
    return torch.cat((q[..., :1], -q[..., 1:]), dim=-1)


def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.

    Expects a tensor of shape (*, 4) for q (scalar part first, (w, x, y, z))
    and a tensor of shape (*, 3) for v. The leading dimensions of q and v
    are broadcast against each other, so e.g. a single quaternion of shape
    (4,) can rotate a batch of vectors of shape (N, 3). When the leading
    dimensions are equal, this behaves exactly as an element-wise rotation.
    Returns a tensor of shape (*B, 3), where *B is the broadcast of the
    leading dimensions of q and v.

    q is assumed to be a unit quaternion; otherwise the result is not a
    pure rotation of v.

    Raises AssertionError if the last dimensions are not 4 and 3, or if the
    leading dimensions of q and v cannot be broadcast together.
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    try:
        batch_shape = torch.broadcast_shapes(q.shape[:-1], v.shape[:-1])
    except RuntimeError as err:
        # Keep the AssertionError that the former equal-shape check raised,
        # so existing callers see the same exception type.
        raise AssertionError(
            f'incompatible batch shapes {tuple(q.shape[:-1])} '
            f'and {tuple(v.shape[:-1])}'
        ) from err
    q = q.expand(*batch_shape, 4)
    v = v.expand(*batch_shape, 3)

    # With q = (w, u): v' = v + 2 * (w * (u x v) + u x (u x v)).
    # Two cross products rotate v without building a rotation matrix.
    u = q[..., 1:]
    uv = torch.cross(u, v, dim=-1)
    uuv = torch.cross(u, uv, dim=-1)
    return v + 2 * (q[..., :1] * uv + uuv)