import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from shap_e.util.collections import AttrDict

from .meta import MetaModule, subdict
from .pointnet2_utils import sample_and_group, sample_and_group_all


def gelu(x):
    # Tanh approximation of the Gaussian Error Linear Unit.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


def quick_gelu(x):
    # Cheaper sigmoid-based approximation of GELU.
    return x * torch.sigmoid(1.702 * x)


def torch_gelu(x):
    # Exact GELU as implemented by PyTorch.
    return torch.nn.functional.gelu(x)


def geglu(x):
    # Gated GELU: splits the last dimension into values and gates, so the
    # output has half as many channels as the input.
    v, gates = x.chunk(2, dim=-1)
    return v * gelu(gates)


class SirenSin:
    """Sine activation with frequency w0, as used in SIREN networks."""

    def __init__(self, w0=30.0):
        self.w0 = w0

    def __call__(self, x):
        return torch.sin(self.w0 * x)


def get_act(name):
    """Look up an activation function by name; raises KeyError for unknown names."""
    return {
        "relu": torch.nn.functional.relu,
        "leaky_relu": torch.nn.functional.leaky_relu,
        "swish": swish,
        "tanh": torch.tanh,
        "gelu": gelu,
        "quick_gelu": quick_gelu,
        "torch_gelu": torch_gelu,
        "gelu2": quick_gelu,
        "geglu": geglu,
        "sigmoid": torch.sigmoid,
        "sin": torch.sin,
        "sin30": SirenSin(w0=30.0),
        "softplus": F.softplus,
        "exp": torch.exp,
        "identity": lambda x: x,
    }[name]


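# Illustrative sketch (not part of the module): get_act returns a plain
# callable, so activations can be selected from configs by name. Note that
# "geglu" halves the last dimension, so the layer producing its input must
# emit twice the desired width:
#
#     act = get_act("geglu")
#     y = act(torch.randn(2, 16))  # y.shape == (2, 8)

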
def zero_init(affine):
    """Zero out the weight (and bias, if present) of a linear layer."""
    nn.init.constant_(affine.weight, 0.0)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_first_layer(affine, init_scale: float = 1.0):
    """SIREN first-layer init: uniform in [-init_scale/n_input, init_scale/n_input]."""
    n_input = affine.weight.shape[1]
    u = init_scale / n_input
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init(affine, coeff=1.0, init_scale: float = 1.0):
    """SIREN hidden-layer init: uniform in [-u, u] with u = init_scale * sqrt(6/n_input) / coeff."""
    n_input = affine.weight.shape[1]
    u = init_scale * np.sqrt(6.0 / n_input) / coeff
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_30(affine, init_scale: float = 1.0):
    """SIREN hidden-layer init matched to sine activations with w0=30."""
    siren_init(affine, coeff=30.0, init_scale=init_scale)


def std_init(affine, init_scale: float = 1.0):
    """Gaussian init with standard deviation init_scale / sqrt(n_in); zero bias."""
    n_in = affine.weight.shape[1]
    stddev = init_scale / math.sqrt(n_in)
    nn.init.normal_(affine.weight, std=stddev)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):
    """Initialize a sequence of linear layers with the named scheme."""
    if init == "siren30":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init_30
            init_fn(affine, init_scale=init_scale)
    elif init == "siren":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init
            init_fn(affine, init_scale=init_scale)
    elif init is None:
        for affine in affines:
            std_init(affine, init_scale=init_scale)
    else:
        raise NotImplementedError(init)


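# Illustrative sketch (not part of the module): initializing a small stack of
# linear layers as a SIREN, with the first layer treated specially:
#
#     layers = [nn.Linear(3, 256), nn.Linear(256, 256), nn.Linear(256, 1)]
#     mlp_init(layers, init="siren30")

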
class MetaLinear(MetaModule):
    def __init__(
        self,
        n_in,
        n_out,
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        **kwargs,
    ):
        super().__init__()
        register_meta_fn = (
            self.register_meta_parameter if trainable_meta else self.register_meta_buffer
        )
        if meta_scale:
            register_meta_fn("scale", nn.Parameter(torch.ones(n_out, **kwargs)))
        if meta_shift:
            register_meta_fn("shift", nn.Parameter(torch.zeros(n_out, **kwargs)))

        register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn
        register_proj_fn("weight", nn.Parameter(torch.empty((n_out, n_in), **kwargs)))

        if not bias:
            self.register_parameter("bias", None)
        else:
            register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn
            register_bias_fn("bias", nn.Parameter(torch.empty(n_out, **kwargs)))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # From https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _bcast(self, op, left, right):
        if right.ndim == 2:
            # Right-hand side has shape [batch x d_output]; add a context axis
            # so it broadcasts against [batch x n_ctx x d_output].
            right = right.unsqueeze(1)
        return op(left, right)

    def forward(self, x, params=None):
        params = self.update(params)

        batch_size, *shape, d_in = x.shape
        x = x.view(batch_size, -1, d_in)

        if params.weight.ndim == 2:
            # Projection shared across the batch.
            h = torch.einsum("bni,oi->bno", x, params.weight)
        elif params.weight.ndim == 3:
            # Per-example projection.
            h = torch.einsum("bni,boi->bno", x, params.weight)
        else:
            raise ValueError(f"unsupported weight ndim: {params.weight.ndim}")

        if params.bias is not None:
            h = self._bcast(torch.add, h, params.bias)
        if params.scale is not None:
            h = self._bcast(torch.mul, h, params.scale)
        if params.shift is not None:
            h = self._bcast(torch.add, h, params.shift)

        h = h.view(batch_size, *shape, -1)
        return h


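# Illustrative sketch (not part of the module): a MetaLinear behaves like
# nn.Linear until per-batch meta parameters are supplied, assuming
# MetaModule.update merges the override with the registered meta values.
# Here the default "scale" buffer is overridden per batch element:
#
#     layer = MetaLinear(8, 4)
#     x = torch.randn(2, 10, 8)
#     y = layer(x)                                            # shape (2, 10, 4)
#     y = layer(x, params=AttrDict(scale=torch.randn(2, 4)))  # per-example scale

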
def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **kwargs):
    """Dimension-generic convolution: dispatches to nn.Conv1d/2d/3d based on n_dim."""
    cls = {
        1: nn.Conv1d,
        2: nn.Conv2d,
        3: nn.Conv3d,
    }[n_dim]
    return cls(d_in, d_out, kernel, stride=stride, padding=padding, dilation=dilation, **kwargs)


def flatten(x):
    """Collapse all spatial dimensions into a single context axis.

    Maps [batch x *shape x channels] to [batch x n_ctx x channels] and returns
    the metadata needed by unflatten to invert the operation.
    """
    batch_size, *shape, n_channels = x.shape
    n_ctx = np.prod(shape)
    return x.view(batch_size, n_ctx, n_channels), AttrDict(
        shape=shape, n_ctx=n_ctx, n_channels=n_channels
    )


def unflatten(x, info):
    """Invert flatten using the metadata it returned."""
    batch_size = x.shape[0]
    return x.view(batch_size, *info.shape, info.n_channels)


def torchify(x):
    """Move channels from the last axis to axis 1 (channels-last to channels-first)."""
    extent = list(range(1, x.ndim - 1))
    return x.permute([0, x.ndim - 1, *extent])


def untorchify(x):
    """Move channels from axis 1 back to the last axis (inverse of torchify)."""
    extent = list(range(2, x.ndim))
    return x.permute([0, *extent, 1])


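# Illustrative sketch (not part of the module): round-tripping the layout
# helpers on a channels-last feature map:
#
#     x = torch.randn(2, 4, 4, 3)      # [batch x H x W x channels]
#     flat, info = flatten(x)          # flat.shape == (2, 16, 3)
#     assert torch.equal(unflatten(flat, info), x)
#     assert torch.equal(untorchify(torchify(x)), x)

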
class MLP(nn.Module):
    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        """
        Required: d_input, d_hidden, d_output
        Optional: act_name, bias, init, init_scale, zero_out
        """
        super().__init__()
        ds = [d_input] + d_hidden + [d_output]
        affines = [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(ds[:-1], ds[1:])]
        self.d = ds
        self.affines = nn.ModuleList(affines)
        self.act = get_act(act_name)
        mlp_init(self.affines, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(affines[-1])

    def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = ""):
        options = AttrDict() if options is None else AttrDict(options)
        *hid, out = self.affines
        for f in hid:
            h = self.act(f(h))
        # No activation after the final projection.
        h = out(h)
        return h


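# Illustrative sketch (not part of the module): a 3 -> 64 -> 64 -> 1 MLP with
# sine activations and matching SIREN initialization, as might be used for an
# implicit function over coordinates (dimensions here are made up):
#
#     mlp = MLP(d_input=3, d_hidden=[64, 64], d_output=1, act_name="sin30", init="siren30")
#     y = mlp(torch.randn(2, 1024, 3))  # y.shape == (2, 1024, 1)

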
class MetaMLP(MetaModule):
    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        super().__init__()
        ds = [d_input] + d_hidden + [d_output]
        affines = [
            MetaLinear(
                d_in,
                d_out,
                bias=bias,
                meta_scale=meta_scale,
                meta_shift=meta_shift,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            for d_in, d_out in zip(ds[:-1], ds[1:])
        ]
        self.d = ds
        self.affines = nn.ModuleList(affines)
        self.act = get_act(act_name)
        mlp_init(affines, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(affines[-1])

    def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = ""):
        options = AttrDict() if options is None else AttrDict(options)
        params = self.update(params)
        *hid, out = self.affines
        for i, layer in enumerate(hid):
            h = self.act(layer(h, params=subdict(params, f"{log_prefix}affines.{i}")))
        last = len(self.affines) - 1
        h = out(h, params=subdict(params, f"{log_prefix}affines.{last}"))
        return h


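# Illustrative sketch (not part of the module): like MetaLinear, a MetaMLP can
# be run with externally supplied meta parameters, routed to individual layers
# by subdict using dotted keys (assuming the default meta_scale registration
# and that MetaModule.update merges overrides with registered values):
#
#     mlp = MetaMLP(d_input=8, d_hidden=[32], d_output=4)
#     x = torch.randn(2, 10, 8)
#     y = mlp(x)  # uses the registered meta buffers
#     y = mlp(x, params=AttrDict({"affines.0.scale": torch.randn(2, 32)}))

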
class LayerNorm(nn.LayerNorm):
    def __init__(
        self, norm_shape: Union[int, Tuple[int]], eps: float = 1e-5, elementwise_affine: bool = True
    ):
        super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(norm_shape)
        # Threshold for routing large inputs through the functional path;
        # 65535 appears to correspond to a CUDA grid-dimension limit.
        self.max_numel = 65535 * self.width

    def forward(self, input):
        # Both branches normalize in fp32 and cast back to the input dtype.
        if input.numel() > self.max_numel:
            return F.layer_norm(
                input.float(), self.normalized_shape, self.weight, self.bias, self.eps
            ).type_as(input)
        else:
            return super().forward(input.float()).type_as(input)


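# Because normalization always happens in fp32 before casting back, this
# LayerNorm is safe to use inside fp16/bf16 models. A minimal check:
#
#     ln = LayerNorm(64)
#     x = torch.randn(2, 10, 64, dtype=torch.float16)
#     assert ln(x).dtype == torch.float16

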
class PointSetEmbedding(nn.Module):
    def __init__(
        self,
        *,
        radius: float,
        n_point: int,
        n_sample: int,
        d_input: int,
        d_hidden: List[int],
        patch_size: int = 1,
        stride: int = 1,
        activation: str = "swish",
        group_all: bool = False,
        padding_mode: str = "zeros",
        fps_method: str = "fps",
        **kwargs,
    ):
        super().__init__()
        self.n_point = n_point
        self.radius = radius
        self.n_sample = n_sample
        self.mlp_convs = nn.ModuleList()
        self.act = get_act(activation)
        self.patch_size = patch_size
        self.stride = stride
        # Each grouped point carries its features plus 3 relative coordinates.
        last_channel = d_input + 3
        for out_channel in d_hidden:
            self.mlp_convs.append(
                nn.Conv2d(
                    last_channel,
                    out_channel,
                    kernel_size=(patch_size, 1),
                    stride=(stride, 1),
                    padding=(patch_size // 2, 0),
                    padding_mode=padding_mode,
                    **kwargs,
                )
            )
            last_channel = out_channel
        self.group_all = group_all
        self.fps_method = fps_method

    def forward(self, xyz, points):
        """
        Input:
            xyz: input point position data, [B, C, N]
            points: input point feature data, [B, D, N]
        Return:
            new_points: sampled point feature data, [B, d_hidden[-1], n_point]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(
                self.n_point,
                self.radius,
                self.n_sample,
                xyz,
                points,
                deterministic=not self.training,
                fps_method=self.fps_method,
            )
        # new_xyz: sampled point position data, [B, n_point, C]
        # new_points: sampled point feature data, [B, n_point, n_sample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, n_sample, n_point]
        for conv in self.mlp_convs:
            new_points = self.act(self.apply_conv(new_points, conv))
        # Average-pool over the n_sample axis to get one feature per sampled point.
        new_points = new_points.mean(dim=2)
        return new_points

    def apply_conv(self, points: torch.Tensor, conv: nn.Module):
        batch, channels, n_samples, _ = points.shape
        if self.patch_size > 1:
            # Shuffle samples within each group so the (patch_size x 1)
            # convolution does not depend on the arbitrary ordering returned
            # by grouping.
            # TODO: shuffle deterministically when not self.training.
            _, indices = torch.rand(batch, channels, n_samples, 1, device=points.device).sort(dim=2)
            points = torch.gather(points, 2, torch.broadcast_to(indices, points.shape))
        return conv(points)


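# Illustrative sketch (not part of the module): embedding a 1024-point cloud
# with 3 feature channels into 256 sampled point features. sample_and_group
# comes from the pointnet2_utils dependency imported above; the dimensions
# here are made up for the example:
#
#     embed = PointSetEmbedding(
#         radius=0.2, n_point=256, n_sample=32, d_input=3, d_hidden=[64, 128]
#     )
#     xyz = torch.randn(2, 3, 1024)    # positions, [B, C, N]
#     feats = torch.randn(2, 3, 1024)  # features, [B, D, N]
#     out = embed(xyz, feats)          # out.shape == (2, 128, 256)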