import copy
import math
from typing import Optional

import torch
import torch.nn.functional as F
from rff.layers import GaussianEncoding, PositionalEncoding
from torch import nn
from .kan.fasterkan import FasterKAN



class Sine(nn.Module):
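    """Sine activation: x -> sin(w0 * x), used as the SIREN nonlinearity."""
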
    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.w0 * x)


def params_to_tensor(params):
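    """Flatten a sequence of parameter tensors into one 1-D tensor and record their shapes."""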
    return torch.cat([p.flatten() for p in params]), [p.shape for p in params]


def tensor_to_params(tensor, shapes):
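    """Inverse of params_to_tensor: slice a flat tensor back into tensors of the given shapes."""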
    params = []
    start = 0
    for shape in shapes:
        size = torch.prod(torch.tensor(shape)).item()
        param = tensor[start : start + size].reshape(shape)
        params.append(param)
        start += size
    return tuple(params)


def wrap_func(func, shapes):
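    """Adapt `func`, which expects a tuple of parameter tensors, to take a single flat tensor."""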
    def wrapped_func(params, *args, **kwargs):
        params = tensor_to_params(params, shapes)
        return func(params, *args, **kwargs)

    return wrapped_func


class Siren(nn.Module):
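    """SIREN layer: a linear map followed by a sine activation, with SIREN-style
    uniform weight initialization scaled by w0 and c."""
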
    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.w0 = w0
        self.c = c
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(self, weight: torch.Tensor, bias: Optional[torch.Tensor], c: float, w0: float):
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if bias is not None:
            # Zero bias initialization (the uniform alternative is kept for reference):
            # bias.uniform_(-w_std, w_std)
            bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        return out


class INR(nn.Module):
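    """Implicit neural representation: an optional Fourier-feature encoding,
    a stack of Siren layers, and a linear output head."""
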
    def __init__(
        self,
        in_features: int = 2,
        n_layers: int = 3,
        hidden_features: int = 32,
        out_features: int = 1,
        pe_features: Optional[int] = None,
        fix_pe=True,
    ):
        super().__init__()

        if pe_features is not None:
            if fix_pe:
                # Deterministic positional encoding: output dim is in_features * pe_features * 2.
                self.layers = [PositionalEncoding(sigma=10, m=pe_features)]
                encoded_dim = in_features * pe_features * 2
            else:
                # Random Fourier (Gaussian) encoding: output dim is pe_features * 2.
                self.layers = [
                    GaussianEncoding(
                        sigma=10, input_size=in_features, encoded_size=pe_features
                    )
                ]
                encoded_dim = pe_features * 2
            self.layers.append(Siren(dim_in=encoded_dim, dim_out=hidden_features))
        else:
            self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)]
        # n_layers counts the Siren layers plus the final linear head; a
        # positional-encoding layer, if present, is extra.
        for _ in range(n_layers - 2):
            self.layers.append(Siren(hidden_features, hidden_features))
        self.layers.append(nn.Linear(hidden_features, out_features))
        self.seq = nn.Sequential(*self.layers)
        self.num_layers = len(self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Constant shift so that near-zero outputs land around 0.5.
        return self.seq(x) + 0.5


class INRPerLayer(INR):
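    """INR variant whose forward pass returns the input together with every
    intermediate activation (the last entry is the shifted network output)."""
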
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        nodes = [x]
        for layer in self.seq:
            nodes.append(layer(nodes[-1]))
        nodes[-1] = nodes[-1] + 0.5
        return nodes


def make_functional(mod, disable_autograd_tracking=False):
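    """Build a stateless, functional view of `mod` on top of torch.func.functional_call.

    Returns a function fmodel(new_params_values, *args, **kwargs) together with the
    module's current parameter values as a tuple of tensors.
    """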
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to("meta")

    def fmodel(new_params_values, *args, **kwargs):
        new_params_dict = {
            name: value for name, value in zip(params_names, new_params_values)
        }
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values
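

# Minimal usage sketch (illustrative only; tensor names and sizes below are
# arbitrary, not part of the module): build an INR over 2-D coordinates, then
# run the same forward pass through the functional / flat-parameter view.
if __name__ == "__main__":
    inr = INR(in_features=2, n_layers=3, hidden_features=32, out_features=1)
    coords = torch.rand(1024, 2) * 2 - 1  # random coordinates in [-1, 1]^2
    preds = inr(coords)                   # shape (1024, 1)

    fmodel, params = make_functional(inr)
    flat, shapes = params_to_tensor(params)
    func = wrap_func(fmodel, shapes)
    preds_from_flat = func(flat, coords)  # same forward pass, parameters as one flat vector

    assert torch.allclose(preds, preds_from_flat)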