File size: 453 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch


def normalize(v: torch.Tensor, *, eps: float = 1e-12) -> torch.Tensor:
    """Scale ``v`` to unit Euclidean (L2) length along its last dimension.

    Args:
        v: Floating-point tensor; each slice along the last dimension is
            treated as one vector. Leading dimensions are batch dimensions.
        eps: Lower bound applied to each norm before dividing. Without it an
            all-zero vector yields ``0 / 0 = NaN``, which then propagates into
            everything computed from the result. With it, a vector whose norm
            is below ``eps`` is divided by ``eps`` instead, so an all-zero
            vector stays zero. Pass ``eps=0.0`` to get the unguarded division.

    Returns:
        Tensor with the same shape as ``v``. Every last-dimension slice whose
        norm is at least ``eps`` has unit length.
    """
    # keepdim=True leaves a trailing size-1 axis so the division broadcasts
    # back over every component of each vector.
    norms = torch.linalg.norm(v, dim=-1, keepdim=True)
    return v / norms.clamp_min(eps)


def cross_product(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """Return the 3-D cross product ``v1 x v2`` taken along the last axis.

    Only indices 0, 1 and 2 of the last dimension are read; any further
    entries are ignored, and fewer than three raises an ``IndexError``.
    Leading dimensions of the two inputs combine by ordinary elementwise
    broadcasting. The result has a trailing dimension of size 3.
    """
    ax, ay, az = (v1[..., i] for i in range(3))
    bx, by, bz = (v2[..., i] for i in range(3))

    cx = ay * bz - by * az
    # Kept as a negated difference (cofactor-expansion form) so results,
    # including the sign of exact zeros, are bit-identical to that form.
    cy = -(ax * bz - bx * az)
    cz = ax * by - bx * ay

    return torch.stack((cx, cy, cz), dim=-1)