Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from diffusers.configuration_utils import ConfigMixin, register_to_config | |
from diffusers import ModelMixin | |
from torch import Tensor | |
from .temporaltrans.temptrans import SimpleTemperalPointModel, SimpleTransModel | |
class PointModel(ModelMixin, ConfigMixin): | |
def __init__( | |
self, | |
model_type: str = 'pvcnn', | |
in_channels: int = 3, | |
out_channels: int = 3, | |
embed_dim: int = 64, | |
dropout: float = 0.1, | |
width_multiplier: int = 1, | |
voxel_resolution_multiplier: int = 1, | |
): | |
super().__init__() | |
self.model_type = model_type | |
if self.model_type == 'simple': | |
self.autocast_context = torch.autocast('cuda', dtype=torch.float32) | |
self.model = SimpleTransModel( | |
embed_dim=embed_dim, | |
num_classes=out_channels, | |
extra_feature_channels=(in_channels - 3), | |
) | |
self.model.output_projection.bias.data.normal_(0, 1e-6) | |
self.model.output_projection.weight.data.normal_(0, 1e-6) | |
else: | |
raise NotImplementedError() | |
def forward(self, inputs: Tensor, t: Tensor, context=None) -> Tensor: | |
""" Receives input of shape (B, N, in_channels) and returns output | |
of shape (B, N, out_channels) """ | |
with self.autocast_context: | |
return self.model(inputs, t, context) | |