# Spaces: Runtime error
from dataclasses import dataclass
from typing import Callable, Optional

import torch
# eq=False: a generated __eq__ would compare the tensor fields via tuple
# equality, which calls bool() on elementwise tensor comparisons and raises
# for multi-element tensors. Identity equality (and hashability) is kept.
@dataclass(eq=False)
class Query:
    """A batch of query positions with optional direction and t_min/t_max tensors.

    ``copy`` and ``map_tensors`` return new ``Query`` objects and never
    mutate ``self``.
    """

    # Both of these are of shape [batch_size x ... x 3]
    position: torch.Tensor
    direction: Optional[torch.Tensor] = None

    # Optional per-query tensors; their shape is not established here
    # (NOTE(review): presumably [batch_size x ...] bounds -- confirm with callers).
    t_min: Optional[torch.Tensor] = None
    t_max: Optional[torch.Tensor] = None

    def copy(self) -> "Query":
        """Return a shallow copy.

        The new ``Query`` references the very same tensor objects as ``self``;
        no tensor data is cloned.
        """
        return Query(
            position=self.position,
            direction=self.direction,
            t_min=self.t_min,
            t_max=self.t_max,
        )

    def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Query":
        """Return a new ``Query`` with ``f`` applied to every tensor field.

        Args:
            f: Function applied to ``position`` and to each optional field
                that is not ``None``.

        Returns:
            A new ``Query``; optional fields that were ``None`` stay ``None``
            and ``f`` is not called on them.
        """
        return Query(
            position=f(self.position),
            direction=f(self.direction) if self.direction is not None else None,
            t_min=f(self.t_min) if self.t_min is not None else None,
            t_max=f(self.t_max) if self.t_max is not None else None,
        )