Spaces:
Sleeping
Sleeping
import copy | |
import math | |
from typing import Optional | |
import torch | |
import torch.nn.functional as F | |
from rff.layers import GaussianEncoding, PositionalEncoding | |
from torch import nn | |
from .kan.fasterkan import FasterKAN | |
class Sine(nn.Module):
    """Sine activation with a frequency scale: ``sin(w0 * x)``.

    Args:
        w0: Multiplier applied to the input before the sine is taken.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the element-wise sine of ``x`` scaled by ``self.w0``."""
        scaled = self.w0 * x
        return scaled.sin()
def params_to_tensor(params):
    """Flatten a collection of tensors into a single 1-D tensor.

    Args:
        params: Non-empty iterable of tensors. It is materialised once up
            front so that one-shot iterators (for example the generator
            returned by ``module.parameters()``) are handled correctly
            instead of being exhausted before the shapes are collected.

    Returns:
        A pair ``(flat, shapes)``. ``flat`` is the concatenation of every
        flattened tensor in iteration order; ``shapes`` is the list of the
        tensors' original shapes in the same order.
    """
    # The input is traversed twice below; a generator would be empty on the
    # second pass, so snapshot it into a list first.
    params = list(params)
    flat = torch.cat([p.flatten() for p in params])
    shapes = [p.shape for p in params]
    return flat, shapes
def tensor_to_params(tensor, shapes):
    """Split a flat tensor back into tensors of the given shapes.

    Args:
        tensor: 1-D tensor holding the concatenated, flattened values.
        shapes: Sequence of shapes (``torch.Size``, tuple or list of ints).
            The empty shape ``()`` is allowed and consumes one element.

    Returns:
        A tuple with one tensor per entry of ``shapes``, each a reshaped
        slice of ``tensor`` taken consecutively from the start.
    """
    params = []
    start = 0
    for shape in shapes:
        # math.prod stays in integer arithmetic and returns 1 for an empty
        # shape, so 0-dim entries produce a valid integer slice bound.
        size = math.prod(shape)
        params.append(tensor[start : start + size].reshape(shape))
        start += size
    return tuple(params)
def wrap_func(func, shapes):
    """Adapt ``func`` so it takes its parameters as one flat tensor.

    Args:
        func: Callable whose first positional argument is a tuple of
            parameter tensors.
        shapes: Shapes used to split the flat tensor via
            ``tensor_to_params`` before ``func`` is invoked.

    Returns:
        A callable ``wrapped_func(params, *args, **kwargs)`` that unpacks the
        flat ``params`` tensor and forwards everything else to ``func``
        unchanged.
    """

    def wrapped_func(params, *args, **kwargs):
        unpacked = tensor_to_params(params, shapes)
        return func(unpacked, *args, **kwargs)

    return wrapped_func
class Siren(nn.Module):
    """Linear layer followed by an activation (``Sine(w0)`` by default).

    Args:
        dim_in: Size of the last input dimension.
        dim_out: Size of the last output dimension.
        w0: Frequency scale; used by the weight initialisation and by the
            default ``Sine`` activation.
        c: Constant used in the weight-initialisation bound.
        is_first: Selects the ``1 / dim_in`` initialisation bound instead of
            ``sqrt(c / dim_in) / w0``.
        use_bias: When false, no bias is created and ``self.bias`` is ``None``.
        activation: Callable applied to the linear output; ``None`` selects
            ``Sine(w0)``.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.w0, self.c = w0, c
        self.dim_in, self.dim_out = dim_in, dim_out
        self.is_first = is_first

        # Fill plain tensors in place first, then wrap them as parameters.
        weight = torch.zeros(dim_out, dim_in)
        bias = None
        if use_bias:
            bias = torch.zeros(dim_out)
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = None if bias is None else nn.Parameter(bias)

        if activation is None:
            activation = Sine(w0)
        self.activation = activation

    def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float):
        """Initialise ``weight`` (and ``bias``, which may be ``None``) in place.

        The weight is sampled uniformly from ``[-bound, bound]`` where
        ``bound`` is ``1 / dim_in`` for a first layer and
        ``sqrt(c / dim_in) / w0`` otherwise.
        """
        fan_in = self.dim_in
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0
        weight.uniform_(-bound, bound)

        if bias is None:
            return
        # The bias starts at zero rather than being sampled like the weight.
        bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map and then the activation."""
        return self.activation(F.linear(x, self.weight, self.bias))
class INR(nn.Module):
    """Stack of ``Siren`` layers with a final linear layer.

    The network is an optional input encoding, one ``Siren`` input layer,
    ``n_layers - 2`` hidden ``Siren`` layers and a final ``nn.Linear``. The
    forward pass adds a constant ``0.5`` to the output.

    Args:
        in_features: Size of the input coordinate vector.
        n_layers: Layer count; ``n_layers - 2`` hidden ``Siren`` layers are
            inserted between the input layer and the linear output layer.
        hidden_features: Width of every ``Siren`` layer.
        out_features: Size of the output vector.
        pe_features: When not ``None``, an encoding layer is placed in front
            of the first ``Siren`` layer.
        fix_pe: Chooses ``PositionalEncoding`` (true) or ``GaussianEncoding``
            (false) when ``pe_features`` is given.
    """

    def __init__(
        self,
        in_features: int = 2,
        n_layers: int = 3,
        hidden_features: int = 32,
        out_features: int = 1,
        pe_features: Optional[int] = None,
        fix_pe=True,
    ):
        super().__init__()
        stack = []
        first_in = in_features

        if pe_features is not None:
            if fix_pe:
                encoder = PositionalEncoding(sigma=10, m=pe_features)
                first_in = in_features * pe_features * 2
            else:
                encoder = GaussianEncoding(
                    sigma=10, input_size=in_features, encoded_size=pe_features
                )
                first_in = pe_features * 2
            stack.append(encoder)

        # NOTE(review): no layer is built with is_first=True, so every Siren
        # uses the sqrt(c / dim) / w0 initialisation bound - confirm intended.
        stack.append(Siren(dim_in=first_in, dim_out=hidden_features))
        stack.extend(
            Siren(hidden_features, hidden_features) for _ in range(n_layers - 2)
        )
        stack.append(nn.Linear(hidden_features, out_features))

        self.layers = stack
        self.seq = nn.Sequential(*stack)
        self.num_layers = len(stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full stack on ``x`` and shift the result by ``0.5``."""
        return self.seq(x) + 0.5
class INRPerLayer(INR):
    """``INR`` variant whose forward pass exposes every intermediate output."""

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        """Run the stack layer by layer, keeping each intermediate result.

        Args:
            x: Input passed to the first layer of ``self.seq``.

        Returns:
            A list with ``len(self.seq) + 1`` entries: the input ``x``
            followed by the output of each layer in order. The last entry
            has ``0.5`` added, matching the offset applied by
            ``INR.forward``.
        """
        nodes = [x]
        for layer in self.seq:
            nodes.append(layer(nodes[-1]))
        # Out-of-place add: replaces the list entry without mutating the
        # tensor produced by the final layer.
        nodes[-1] = nodes[-1] + 0.5
        return nodes
def make_functional(mod, disable_autograd_tracking=False):
    """Split a module into a stateless callable and its parameter tuple.

    Args:
        mod: Module to convert. It is deep-copied, so ``mod`` itself is left
            untouched.
        disable_autograd_tracking: When true, the returned parameter values
            are detached from the autograd graph.

    Returns:
        A pair ``(fmodel, params_values)``. ``fmodel(new_params_values,
        *args, **kwargs)`` calls the copied module with the supplied values
        substituted for its parameters, matched positionally in
        ``named_parameters()`` order. ``params_values`` is the tuple of the
        module's current parameters in that same order.
    """
    named = dict(mod.named_parameters())
    names = tuple(named)
    params_values = tuple(named.values())

    # Keep only the module structure; the copy is moved to the "meta" device.
    # NOTE(review): .to("meta") also moves buffers, and fmodel substitutes
    # parameters only - confirm modules with buffers are not passed here.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to("meta")

    def fmodel(new_params_values, *args, **kwargs):
        substituted = dict(zip(names, new_params_values))
        return torch.func.functional_call(stateless_mod, substituted, args, kwargs)

    if disable_autograd_tracking:
        params_values = tuple(p.detach() for p in params_values)
    return fmodel, params_values