silentchen's picture
first commit
19c4ddf
raw
history blame
893 Bytes
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class Query:
    """Bundle of tensors describing a batch of query points.

    Only ``position`` is required; the remaining fields default to ``None``
    and are passed through untouched when absent.
    """

    # Both position and direction are of shape [batch_size x ... x 3]
    position: torch.Tensor
    direction: Optional[torch.Tensor] = None

    # NOTE(review): presumably per-query lower/upper bounds along
    # `direction`; shape and semantics are not visible here -- confirm.
    t_min: Optional[torch.Tensor] = None
    t_max: Optional[torch.Tensor] = None

    def copy(self) -> "Query":
        """Return a new ``Query`` whose fields reference the same objects.

        This is a shallow copy: the tensors themselves are not cloned, so
        in-place modification of a tensor is visible through both queries.
        """
        return Query(
            position=self.position,
            direction=self.direction,
            t_min=self.t_min,
            t_max=self.t_max,
        )

    def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Query":
        """Return a new ``Query`` with ``f`` applied to every present tensor.

        Args:
            f: Function called once per tensor field, in declaration order.

        Returns:
            A new ``Query``. ``position`` is always passed through ``f``;
            optional fields that are ``None`` stay ``None`` and are never
            passed to ``f``.
        """
        return Query(
            position=f(self.position),
            direction=f(self.direction) if self.direction is not None else None,
            t_min=f(self.t_min) if self.t_min is not None else None,
            t_max=f(self.t_max) if self.t_max is not None else None,
        )