# LAM: vhap/util/vector_ops.py
import torch


def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Dot product over the last dimension, kept as a size-1 dim for broadcasting."""
    return torch.sum(x * y, -1, keepdim=True)
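

# Illustrative shape check (a sketch, not part of the original file):
# keepdim=True makes dot() return shape (..., 1), which broadcasts cleanly
# against the (..., 3) inputs in the helpers below.
def _demo_dot() -> None:
    x = torch.randn(4, 3)
    y = torch.randn(4, 3)
    assert dot(x, y).shape == (4, 1)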


def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    """Reflect x about the (assumed unit-length) normal n: 2(x.n)n - x."""
    return 2 * dot(x, n) * n - x
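

# Illustrative check (a sketch, not part of the original file): reflect()
# preserves the component of x along n and negates the tangential component,
# i.e. the mirror law.
def _demo_reflect() -> None:
    n = torch.tensor([0.0, 0.0, 1.0])  # unit normal
    x = torch.tensor([1.0, 0.0, 1.0])
    assert torch.allclose(reflect(x, n), torch.tensor([-1.0, 0.0, 1.0]))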


def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Euclidean norm over the last dimension."""
    # Clamp to avoid NaN gradients, because grad(sqrt(0)) = NaN.
    return torch.sqrt(torch.clamp(dot(x, x), min=eps))


def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Normalize x to unit length; safe at the zero vector via the eps clamp in length()."""
    return x / length(x, eps)
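

# Illustrative check (a sketch, not part of the original file): at the zero
# vector a naive x / x.norm() already produces NaN in the forward pass, while
# safe_normalize stays finite (the grad is large but not NaN) thanks to the
# eps clamp inside length().
def _demo_safe_normalize() -> None:
    x = torch.zeros(3, requires_grad=True)
    safe_normalize(x).sum().backward()
    assert torch.isfinite(x.grad).all()  # finite gradient instead of NaN
    naive = torch.zeros(3) / torch.zeros(3).norm()
    assert torch.isnan(naive).all()  # naive normalization NaNs at zero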


def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
    """Append a homogeneous coordinate w to the last dimension."""
    return torch.nn.functional.pad(x, pad=(0, 1), mode='constant', value=w)
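

# Usage sketch (illustrative, not part of the original file): in homogeneous
# coordinates, points conventionally get w=1 and directions get w=0.
def _demo_to_hvec() -> None:
    p = torch.tensor([1.0, 2.0, 3.0])
    assert to_hvec(p, 1.0).tolist() == [1.0, 2.0, 3.0, 1.0]  # point
    assert to_hvec(p, 0.0).tolist() == [1.0, 2.0, 3.0, 0.0]  # direction


if __name__ == "__main__":
    _demo_dot()
    _demo_reflect()
    _demo_safe_normalize()
    _demo_to_hvec()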