|
import torch |
|
|
|
|
|
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Inner product of ``x`` and ``y`` along the last dimension.

    The elementwise product ``x * y`` follows normal torch broadcasting.
    It is then summed over dim -1, and that dim is kept with size 1 so the
    result can be broadcast back against vector-shaped inputs.

    Args:
        x: First operand.
        y: Second operand, broadcastable against ``x``.

    Returns:
        Tensor with the broadcast shape of ``x * y``, except the last
        dimension has size 1.
    """
    elementwise = x * y
    return elementwise.sum(dim=-1, keepdim=True)
|
|
|
def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    """Compute ``2 * (x . n) * n - x`` along the last dimension.

    This mirrors ``x`` about the axis ``n`` when ``n`` has unit length.
    NOTE(review): callers appear to be expected to pass a normalized ``n``;
    nothing here normalizes it, so confirm at call sites.

    Args:
        x: Vector(s) to reflect.
        n: Reflection axis (presumably unit length), broadcastable against ``x``.

    Returns:
        The reflected vector(s), with the broadcast shape of ``x`` and ``n``.
    """
    # dot() keeps the reduced dim (size 1), so the projection broadcasts over n.
    projection = dot(x, n)
    return 2 * projection * n - x
|
|
|
def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    """Euclidean norm along the last dimension, bounded away from zero.

    The squared norm is clamped to at least ``eps`` before the square root.
    The result is therefore never below ``sqrt(eps)``, and ``sqrt`` is never
    evaluated at exactly 0, where its gradient is infinite.

    Args:
        x: Input vector(s).
        eps: Lower bound applied to the squared norm.

    Returns:
        Tensor shaped like ``x``, with the last dimension reduced to size 1.
    """
    squared_norm = dot(x, x)
    return squared_norm.clamp(min=eps).sqrt()
|
|
|
def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    """Scale ``x`` to unit length along the last dimension.

    The divisor comes from ``length(x, eps)``, which clamps the squared norm
    to ``eps``. A zero vector is therefore divided by ``sqrt(eps)`` rather
    than by 0.
    NOTE(review): the default 1e-20 may underflow to 0 in float16; confirm
    callers only use float32/float64.

    Args:
        x: Input vector(s).
        eps: Lower bound on the squared norm, forwarded to ``length``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    norm = length(x, eps)
    return torch.div(x, norm)
|
|
|
def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
    """Append a constant ``w`` as one extra trailing component of the last dim.

    This is typically used to lift vectors to homogeneous coordinates
    (e.g. ``w=1`` for points, ``w=0`` for directions). That usage is
    presumed from the name; verify against callers.

    Args:
        x: Input tensor; its last dimension grows by one.
        w: Fill value for the new trailing component.

    Returns:
        Tensor shaped like ``x``, but with the last dimension one larger.
    """
    # (left, right) padding for the last dim only: nothing before, one slot after.
    trailing_pad = (0, 1)
    return torch.nn.functional.pad(x, trailing_pad, 'constant', w)
|
|