IlayMalinyak
cnnkan
2f54ec8
import copy
import math
from typing import Optional
import torch
import torch.nn.functional as F
from rff.layers import GaussianEncoding, PositionalEncoding
from torch import nn
from .kan.fasterkan import FasterKAN
class Sine(nn.Module):
    """Element-wise sine activation with a frequency factor: ``sin(w0 * x)``."""

    def __init__(self, w0=1.0):
        super().__init__()
        # Stored as a plain attribute, so it is not a trainable parameter.
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by ``w0`` and apply ``torch.sin`` element-wise."""
        scaled = self.w0 * x
        return torch.sin(scaled)
def params_to_tensor(params):
    """Flatten a collection of tensors into one 1-D tensor.

    Args:
        params: Iterable of tensors. It is traversed exactly once, so
            one-shot iterables (e.g. the generator returned by
            ``module.parameters()``) are handled correctly.

    Returns:
        A ``(flat, shapes)`` pair: ``flat`` is the concatenation of every
        flattened tensor in iteration order, and ``shapes`` is the list of
        the original shapes in the same order.
    """
    flat_chunks = []
    shapes = []
    # Single pass: iterating twice would leave ``shapes`` empty whenever
    # ``params`` is a generator that the first traversal exhausted.
    for p in params:
        flat_chunks.append(p.flatten())
        shapes.append(p.shape)
    return torch.cat(flat_chunks), shapes
def tensor_to_params(tensor, shapes):
    """Split a flat tensor into consecutive chunks reshaped to ``shapes``.

    Args:
        tensor: 1-D tensor holding the concatenated values.
        shapes: Sequence of shapes (``torch.Size``, tuples or lists of ints);
            an empty shape denotes a scalar and consumes one element.

    Returns:
        Tuple of tensors, one per entry of ``shapes``, taken from ``tensor``
        in order.
    """
    params = []
    offset = 0
    for shape in shapes:
        if isinstance(shape, int):
            # A bare int was accepted by the previous implementation.
            shape = (shape,)
        # ``math.prod`` always returns an int (1 for an empty shape). The
        # earlier ``torch.prod(torch.tensor(shape))`` produced the float 1.0
        # for scalars, which is not a valid slice bound.
        numel = math.prod(shape)
        chunk = tensor[offset : offset + numel]
        params.append(chunk.reshape(shape))
        offset += numel
    return tuple(params)
def wrap_func(func, shapes):
    """Adapt ``func`` so that it accepts its parameters as one flat tensor.

    The returned callable takes a flat tensor as first argument, splits it
    into a tuple of tensors according to ``shapes`` via ``tensor_to_params``
    and forwards that tuple, together with any extra positional and keyword
    arguments, to ``func``.
    """

    def wrapped_func(params, *args, **kwargs):
        return func(tensor_to_params(params, shapes), *args, **kwargs)

    return wrapped_func
class Siren(nn.Module):
    """Linear layer followed by an activation (``Sine(w0)`` by default).

    The weight is drawn uniformly from ``[-b, b)`` where ``b = 1 / dim_in``
    when ``is_first`` is set and ``b = sqrt(c / dim_in) / w0`` otherwise.
    The bias, when present, starts at zero.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.w0 = w0
        self.c = c
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.is_first = is_first

        # Initialise raw tensors first, then wrap them as parameters.
        weight = torch.zeros(dim_out, dim_in)
        if use_bias:
            bias = torch.zeros(dim_out)
        else:
            bias = None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = None if bias is None else nn.Parameter(bias)

        if activation is None:
            activation = Sine(w0)
        self.activation = activation

    def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float):
        """Fill ``weight`` and ``bias`` in place; ``bias`` may be ``None``."""
        fan_in = self.dim_in
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0
        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map and then the activation."""
        return self.activation(F.linear(x, self.weight, self.bias))
class INR(nn.Module):
    """Stack of ``Siren`` layers with a linear head and optional input encoding.

    Layout: ``[encoding] -> Siren -> (n_layers - 2) x Siren -> Linear``.
    ``n_layers`` therefore counts the Siren layers plus the final linear
    layer; values below 2 still yield one Siren and the linear head.

    When ``pe_features`` is given, an encoding layer is placed in front:
    ``PositionalEncoding`` if ``fix_pe`` is true (the first Siren is sized
    for ``in_features * pe_features * 2`` inputs), otherwise
    ``GaussianEncoding`` (first Siren sized for ``pe_features * 2`` inputs).
    """

    def __init__(
        self,
        in_features: int = 2,
        n_layers: int = 3,
        hidden_features: int = 32,
        out_features: int = 1,
        pe_features: Optional[int] = None,
        fix_pe=True,
    ):
        super().__init__()
        if pe_features is None:
            # No encoding: the first Siren consumes the raw coordinates.
            self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)]
        else:
            if fix_pe:
                encoder = PositionalEncoding(sigma=10, m=pe_features)
                encoded_dim = in_features * pe_features * 2
            else:
                encoder = GaussianEncoding(
                    sigma=10, input_size=in_features, encoded_size=pe_features
                )
                encoded_dim = pe_features * 2
            self.layers = [
                encoder,
                Siren(dim_in=encoded_dim, dim_out=hidden_features),
            ]

        self.layers.extend(
            Siren(hidden_features, hidden_features) for _ in range(n_layers - 2)
        )
        self.layers.append(nn.Linear(hidden_features, out_features))

        # ``layers`` stays a plain list; ``seq`` is what registers the modules.
        self.seq = nn.Sequential(*self.layers)
        self.num_layers = len(self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stack and shift the output by a constant 0.5."""
        out = self.seq(x)
        return out + 0.5
class INRPerLayer(INR):
    """``INR`` variant whose forward pass exposes every intermediate result."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a list holding the input followed by each layer's output.

        The last entry is the final layer's output shifted by 0.5, matching
        ``INR.forward``. NOTE: the return annotation mirrors the parent
        signature, but the value actually returned is a list of tensors.
        """
        activations = [x]
        current = x
        for layer in self.seq:
            current = layer(current)
            activations.append(current)
        # Only the final entry receives the constant output offset.
        activations[-1] = current + 0.5
        return activations
def make_functional(mod, disable_autograd_tracking=False):
    """Split a module into a stateless callable and its parameter values.

    Args:
        mod: Module to convert. It is deep-copied; the copy is moved to the
            ``meta`` device so it keeps structure but no parameter storage.
        disable_autograd_tracking: When true, the returned parameter values
            are detached from the autograd graph.

    Returns:
        ``(fmodel, params_values)`` where ``params_values`` is a tuple of the
        module's parameters in ``named_parameters()`` order, and
        ``fmodel(new_params_values, *args, **kwargs)`` evaluates the module
        with ``new_params_values`` substituted for its parameters.

    Buffers are snapshotted at creation time and supplied on every call.
    Without this they would live on the ``meta`` device inside the copy, and
    modules with buffers (e.g. batch norm running statistics) could not be
    evaluated.
    """
    named_params = dict(mod.named_parameters())
    params_names = tuple(named_params)
    params_values = tuple(named_params.values())

    # Independent copies, taken before the module copy is moved to ``meta``.
    buffers = {name: buf.detach().clone() for name, buf in mod.named_buffers()}

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to("meta")

    def fmodel(new_params_values, *args, **kwargs):
        state = dict(zip(params_names, new_params_values))
        state.update(buffers)
        return torch.func.functional_call(stateless_mod, state, args, kwargs)

    if disable_autograd_tracking:
        params_values = tuple(p.detach() for p in params_values)
    return fmodel, params_values