Spaces:
Runtime error
Runtime error
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import tinycudann as tcnn | |
| # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. | |
class Embedder:
    """Sinusoidal positional encoding built from a keyword configuration.

    Keys read from ``kwargs``:
        input_dims:     number of channels of the tensor passed to ``embed``.
        include_input:  if truthy, the raw input is the first output chunk.
        max_freq_log2:  log2 of the highest frequency.
        num_freqs:      number of frequency bands.
        log_sampling:   if truthy, bands are ``2 ** linspace(0, max_freq_log2)``;
                        otherwise they are spaced linearly between ``2 ** 0`` and
                        ``2 ** max_freq_log2``.
        periodic_fns:   callables applied to ``x * freq`` for each band.

    After construction ``embed_fns`` holds the per-chunk callables and
    ``out_dim`` the channel count of the concatenated result.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """(Re)build ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        channels = cfg['input_dims']

        # Optional identity chunk comes first in the output.
        fns = [lambda x: x] if cfg['include_input'] else []

        top_log2 = cfg['max_freq_log2']
        band_count = cfg['num_freqs']
        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top_log2, band_count)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top_log2, band_count)

        # Default arguments freeze the current band / function for each lambda;
        # a plain closure would see only the last loop values.
        fns.extend(
            (lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
            for freq in bands
            for p_fn in cfg['periodic_fns']
        )

        self.embed_fns = fns
        # Every chunk has the same width as the input.
        self.out_dim = channels * len(fns)

    def embed(self, inputs):
        """Apply every embedding function and concatenate along the last axis."""
        chunks = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(chunks, -1)
def get_embedder(multires, input_dims=3):
    """Build a sin/cos positional encoder with ``multires`` log-spaced bands.

    Args:
        multires: number of frequency bands; the highest is ``2 ** (multires - 1)``.
        input_dims: channel count of the tensors that will be embedded.

    Returns:
        ``(embed, out_dim)`` where ``embed(x)`` returns the encoding (raw input
        included) and ``out_dim`` is its channel count.
    """
    embedder_obj = Embedder(
        include_input=True,
        input_dims=input_dims,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
class SDFNetwork(nn.Module):
    """MLP whose first output channel is read as a signed distance.

    The network maps (optionally positionally encoded) inputs through
    ``n_layers`` hidden layers of width ``d_hidden`` with Softplus(beta=100)
    activations to ``d_out`` channels. At every layer index listed in
    ``skip_in`` the (encoded) network input is concatenated to the activations
    and the result is divided by sqrt(2).

    Args:
        d_in: channel count of the raw input.
        d_out: channel count of the output; channel 0 is returned by ``sdf``.
        d_hidden: width of each hidden layer.
        n_layers: number of hidden layers.
        skip_in: layer indices that receive the skip connection.
        multires: number of positional-encoding bands (0 disables the encoding).
        bias: magnitude of the constant used for the last layer's bias when
            ``geometric_init`` is on.
        scale: factor applied to the input before anything else.
        geometric_init: use the hand-crafted initialisation below instead of
            the ``nn.Linear`` default.
        weight_norm: wrap every linear layer in ``nn.utils.weight_norm``.
        inside_outside: flips the sign of the last layer's initial weights and
            bias.
    """

    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False):
        super(SDFNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                # The next layer gets the network input concatenated back in, so
                # this layer emits fewer channels to keep that layer's fan-in at
                # dims[l + 1].
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    # Last layer: near-constant weights and a constant bias;
                    # the sign convention is selected by inside_outside.
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    # First layer with positional encoding: only the first 3
                    # input channels (the raw coordinates when d_in == 3) start
                    # with non-zero weights.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    # Skip layer: zero the weights of the trailing dims[0] - 3
                    # input channels, i.e. the encoded part of the skip input.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, inputs):
        """Return all ``d_out`` channels for ``inputs`` (last axis = ``d_in``)."""
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], -1) / np.sqrt(2)

            x = lin(x)

            # No activation after the final layer.
            if l < self.num_layers - 2:
                x = self.activation(x)

        return x

    def sdf(self, x):
        """Return only the first output channel, keeping the last axis."""
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        """Return every output channel (same as ``forward``)."""
        return self.forward(x)

    def _sdf_and_gradient(self, x, create_graph):
        """Evaluate ``sdf(x)`` and its derivative with respect to ``x``.

        ``x`` is switched to ``requires_grad`` in place, as the public methods
        always did. Grad mode is forced on so this also works when the caller
        is inside ``torch.no_grad()``.
        """
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=torch.ones_like(y),
                create_graph=create_graph,
                retain_graph=create_graph,
            )[0]
        return y, gradients

    def gradient(self, x):
        """Return d sdf / d x, differentiable so it can be used inside a loss."""
        return self._sdf_and_gradient(x, create_graph=True)[1]

    def sdf_normal(self, x):
        """Return detached ``(sdf, gradient)`` for ``x``.

        Both results are detached, so no second-order graph is built here.
        """
        y, gradients = self._sdf_and_gradient(x, create_graph=False)
        return y[..., :1].detach(), gradients.detach()
class SDFNetworkWithFeature(nn.Module):
    """SDF MLP whose input is augmented with features sampled from a volume.

    For every query point a feature vector is interpolated from ``cube`` with
    ``F.grid_sample`` (without gradient). The feature is concatenated to the
    (optionally positionally encoded) point, fed through the MLP, and finally
    appended to the MLP output, so ``forward`` returns ``d_out + C`` channels
    where ``C = cube.shape[1]``.

    Args:
        cube: feature volume registered as a buffer. It is sampled with a
            batch-1 grid of 3-D coordinates, so it is expected to be a 5-D
            tensor with a batch size of 1 and ``df_in`` channels.
        dp_in: channel count of the raw points.
        df_in: channel count of the sampled features.
        d_out: channel count produced by the MLP; channel 0 is returned by ``sdf``.
        d_hidden: width of each hidden layer.
        n_layers: number of hidden layers.
        skip_in: layer indices that receive the skip connection.
        multires: number of positional-encoding bands for the points
            (0 disables the encoding).
        bias: magnitude of the last layer's initial bias under ``geometric_init``.
        scale: factor applied to the points before anything else.
        geometric_init: use the hand-crafted initialisation below.
        weight_norm: wrap every linear layer in ``nn.utils.weight_norm``.
        inside_outside: flips the sign of the last layer's initial weights/bias.
        cube_length: points are divided by this value to obtain grid_sample
            coordinates.
    """

    def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5):
        super().__init__()

        self.register_buffer("cube", cube)
        self.cube_length = cube_length
        dims = [dp_in + df_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch + df_in

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                # Leave room for the skip input that is concatenated before
                # the next layer.
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    # Only the first 3 input channels start with non-zero
                    # weights; encoded channels and features start at zero.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, points):
        """Return MLP output concatenated with the sampled features.

        ``points`` has 3 channels on its last axis; any leading shape is kept.
        """
        points = points * self.scale

        # Dividing by cube_length turns points into grid_sample coordinates
        # (a factor of 2 for the default cube_length of 0.5). The lookup is not
        # differentiated, so input gradients flow only through the MLP path.
        with torch.no_grad():
            # reshape (not view) so non-contiguous batches of points are accepted.
            grid = points.reshape(1, -1, 1, 1, 3) / self.cube_length
            feats = F.grid_sample(self.cube, grid, mode='bilinear', align_corners=True,
                                  padding_mode='zeros').detach()
            # (1, C, N, 1, 1) -> (N, C) -> (*leading shape, C)
            feats = feats.view(self.cube.shape[1], -1).permute(1, 0).reshape(*points.shape[:-1], -1)

        if self.embed_fn_fine is not None:
            points = self.embed_fn_fine(points)

        x = torch.cat([points, feats], -1)
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, points, feats], -1) / np.sqrt(2)

            x = lin(x)

            # No activation after the final layer.
            if l < self.num_layers - 2:
                x = self.activation(x)

        # Append the sampled features to the MLP output.
        x = torch.cat([x, feats], -1)
        return x

    def sdf(self, x):
        """Return only the first output channel, keeping the last axis."""
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        """Return every output channel (same as ``forward``)."""
        return self.forward(x)

    def _sdf_and_gradient(self, x, create_graph):
        """Evaluate ``sdf(x)`` and its derivative with respect to ``x``.

        ``x`` is switched to ``requires_grad`` in place. Grad mode is forced on
        so this also works when the caller is inside ``torch.no_grad()``.
        """
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=torch.ones_like(y),
                create_graph=create_graph,
                retain_graph=create_graph,
            )[0]
        return y, gradients

    def gradient(self, x):
        """Return d sdf / d x, differentiable so it can be used inside a loss."""
        return self._sdf_and_gradient(x, create_graph=True)[1]

    def sdf_normal(self, x):
        """Return detached ``(sdf, gradient)`` for ``x``.

        Both results are detached, so no second-order graph is built here.
        """
        y, gradients = self._sdf_and_gradient(x, create_graph=False)
        return y[..., :1].detach(), gradients.detach()
class VanillaMLP(nn.Module):
    """Plain ``nn.Sequential`` MLP with a hand-crafted initialisation.

    The stack is: first linear + activation, ``n_hidden_layers - 1`` further
    linear + activation pairs of width ``n_neurons``, and a final linear layer
    to ``dim_out``. ``sphere_init`` and ``weight_norm`` are fixed to ``True``
    in the constructor; inputs are cast to float32 in ``forward``.
    """

    def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
        self.sphere_init, self.weight_norm = True, True
        self.sphere_init_radius = 0.5

        # Modules are created strictly in network order so that random
        # initialisation consumes the RNG in the same sequence.
        stack = [
            self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        for _ in range(self.n_hidden_layers - 1):
            stack.append(self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False))
            stack.append(self.make_activation())
        stack.append(self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run the stack on ``x`` after casting it to float32."""
        return self.layers(x.float())

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        """Create one initialised (and optionally weight-normalised) linear layer.

        With ``sphere_init`` the last layer gets near-constant weights and a
        bias of ``-sphere_init_radius``; the first layer only gets non-zero
        weights for its first 3 input channels; other layers get zero bias and
        normal weights. Without it, zero bias and Kaiming-uniform weights.
        """
        layer = nn.Linear(dim_in, dim_out, bias=True)  # network without bias will degrade quality
        weight, offset = layer.weight, layer.bias

        if not self.sphere_init:
            torch.nn.init.constant_(offset, 0.0)
            torch.nn.init.kaiming_uniform_(weight, nonlinearity='relu')
        elif is_last:
            torch.nn.init.constant_(offset, -self.sphere_init_radius)
            torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
        elif is_first:
            torch.nn.init.constant_(offset, 0.0)
            torch.nn.init.constant_(weight[:, 3:], 0.0)
            torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(offset, 0.0)
            torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))

        return nn.utils.weight_norm(layer) if self.weight_norm else layer

    def make_activation(self):
        """Return Softplus(beta=100) under ``sphere_init``, else in-place ReLU."""
        if not self.sphere_init:
            return nn.ReLU(inplace=True)
        return nn.Softplus(beta=100)
class SDFHashGridNetwork(nn.Module):
    """SDF network: tiny-cuda-nn hash-grid encoding followed by a ``VanillaMLP``.

    Points are shifted and scaled by ``bound`` before encoding (so that
    ``[-bound, bound]`` maps to ``[0, 1]``); the raw points are concatenated
    in front of the encoded features and passed to the MLP, which emits
    ``feats_dim`` channels. Channel 0 is returned by ``sdf``.

    Args:
        bound: half extent used to normalise points for the encoder.
        feats_dim: number of output channels of the MLP.
    """

    def __init__(self, bound=0.5, feats_dim=13):
        super().__init__()
        self.bound = bound

        # Hash-grid configuration. A lighter variant that was left commented out
        # in the original: max_resolution=32, base_resolution=16, n_levels=4,
        # log2_hashmap_size=16, n_features_per_level=8.
        max_resolution = 2048
        base_resolution = 16
        n_levels = 16
        log2_hashmap_size = 19
        n_features_per_level = 2

        # max_res = base_res * t^(k-1)
        per_level_scale = (max_resolution / base_resolution) ** (1 / (n_levels - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )
        # +3 because the raw coordinates are concatenated to the encoding.
        self.sdf_mlp = VanillaMLP(n_levels * n_features_per_level + 3, feats_dim, 64, 1)

    def forward(self, x):
        """Return all MLP channels for points ``x`` of shape ``(..., 3)``."""
        shape = x.shape[:-1]
        x = x.reshape(-1, 3)
        # Normalise [-bound, bound] to [0, 1] for the encoder.
        x_ = (x + self.bound) / (2 * self.bound)
        feats = self.encoder(x_)
        feats = torch.cat([x, feats], 1)

        feats = self.sdf_mlp(feats)
        feats = feats.reshape(*shape, -1)
        return feats

    def sdf(self, x):
        """Return only the first output channel, keeping the last axis."""
        return self(x)[..., :1]

    def _sdf_and_gradient(self, x, create_graph):
        """Evaluate ``sdf(x)`` and its derivative with respect to ``x``.

        ``x`` is switched to ``requires_grad`` in place. Grad mode is forced on
        so this also works when the caller is inside ``torch.no_grad()``.
        """
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=torch.ones_like(y),
                create_graph=create_graph,
                retain_graph=create_graph,
            )[0]
        return y, gradients

    def gradient(self, x):
        """Return d sdf / d x, differentiable so it can be used inside a loss."""
        return self._sdf_and_gradient(x, create_graph=True)[1]

    def sdf_normal(self, x):
        """Return detached ``(sdf, gradient)`` for ``x``.

        Both results are detached, so no second-order graph is built here.
        """
        y, gradients = self._sdf_and_gradient(x, create_graph=False)
        return y[..., :1].detach(), gradients.detach()
class RenderingFFNetwork(nn.Module):
    """Colour head built from tiny-cuda-nn modules.

    The reflected view direction is encoded with degree-4 spherical harmonics
    and, together with the feature vector and the unit normal, fed to a
    fully fused ReLU MLP (2 hidden layers, 64 neurons) that outputs 3 channels
    squashed by a sigmoid.

    Args:
        in_feats_dim: channel count of ``feature_vectors`` passed to ``forward``.
    """

    def __init__(self, in_feats_dim=12):
        super().__init__()

        sh_config = {
            "otype": "SphericalHarmonics",
            "degree": 4,
        }
        self.dir_encoder = tcnn.Encoding(n_input_dims=3, encoding_config=sh_config)

        # features + normal (3) + encoded reflection direction
        mlp_in = in_feats_dim + 3 + self.dir_encoder.n_output_dims
        mlp_config = {
            "otype": "FullyFusedMLP",
            "activation": "ReLU",
            "output_activation": "none",
            "n_neurons": 64,
            "n_hidden_layers": 2,
        }
        self.color_mlp = tcnn.Network(
            n_input_dims=mlp_in,
            n_output_dims=3,
            network_config=mlp_config,
        )

    def forward(self, points, normals, view_dirs, feature_vectors):
        """Return sigmoid-activated colours; ``points`` is accepted but unused."""
        unit_n = F.normalize(normals, dim=-1)
        unit_v = F.normalize(view_dirs, dim=-1)

        # Reflect the view direction about the normal: 2 (v . n) n - v
        v_dot_n = torch.sum(unit_v * unit_n, -1, keepdim=True)
        reflective = v_dot_n * unit_n * 2 - unit_v

        mlp_input = torch.cat([feature_vectors, unit_n, self.dir_encoder(reflective)], -1)
        raw = self.color_mlp(mlp_input).float()
        return F.sigmoid(raw)
| # This implementation is borrowed from IDR: https://github.com/lioryariv/idr | |
class RenderingNetwork(nn.Module):
    """ReLU MLP that predicts colour from geometry and appearance features.

    Args:
        d_feature: channel count of ``feature_vectors``.
        d_in: channel count of the remaining concatenated inputs (points,
            normals and, when ``use_view_dir`` is set, the reflected direction
            before positional encoding).
        d_out: output channel count.
        d_hidden: width of each hidden layer.
        n_layers: number of hidden layers.
        weight_norm: wrap every linear layer in ``nn.utils.weight_norm``.
        multires_view: positional-encoding bands for the reflected direction
            (0 disables the encoding).
        squeeze_out: apply ``rgb_act`` (sigmoid) to the output.
        use_view_dir: include the reflected view direction in the input.
    """

    def __init__(self, d_feature, d_in, d_out, d_hidden,
                 n_layers, weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
        super().__init__()

        self.squeeze_out = squeeze_out
        self.rgb_act = F.sigmoid
        self.use_view_dir = use_view_dir

        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            # The 3 raw direction channels are replaced by their encoding.
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for idx, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            layer = nn.Linear(fan_in, fan_out)
            if weight_norm:
                layer = nn.utils.weight_norm(layer)
            setattr(self, "lin" + str(idx), layer)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        """Return the network output for the concatenated inputs.

        With ``use_view_dir`` the normals and view directions are normalised
        and the view direction is reflected about the normal; otherwise the
        normals are used as given and ``view_dirs`` is ignored.
        """
        if self.use_view_dir:
            view_dirs = F.normalize(view_dirs, dim=-1)
            normals = F.normalize(normals, dim=-1)
            # 2 (v . n) n - v
            v_dot_n = torch.sum(view_dirs * normals, -1, keepdim=True)
            reflective = v_dot_n * normals * 2 - view_dirs
            if self.embedview_fn is not None:
                reflective = self.embedview_fn(reflective)
            pieces = [points, reflective, normals, feature_vectors]
        else:
            pieces = [points, normals, feature_vectors]

        x = torch.cat(pieces, dim=-1)

        last_hidden = self.num_layers - 2
        for idx in range(self.num_layers - 1):
            x = getattr(self, "lin" + str(idx))(x)
            # No ReLU after the final layer.
            if idx < last_hidden:
                x = self.relu(x)

        if self.squeeze_out:
            x = self.rgb_act(x)
        return x
class SingleVarianceNetwork(nn.Module):
    """Holds one learnable scalar and broadcasts a value derived from it.

    Args:
        init_val: initial value of the ``variance`` parameter.
        activation: only ``'exp'`` is implemented; ``forward`` raises
            ``NotImplementedError`` for anything else.
    """

    def __init__(self, init_val, activation='exp'):
        super(SingleVarianceNetwork, self).__init__()
        self.act = activation
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        """Return ``exp(10 * variance)`` as a float32 tensor of shape ``(*x.shape[:-1], 1)``."""
        target = x.device
        if self.act != 'exp':
            raise NotImplementedError
        column = torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=target)
        return column * torch.exp(self.variance * 10.0)

    def warp(self, x, inv_s):
        """Broadcast ``inv_s`` to a float32 tensor of shape ``(*x.shape[:-1], 1)`` on ``x``'s device."""
        column = torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=x.device)
        return column * inv_s