from dataclasses import dataclass, field
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
from torch.amp import custom_bwd, custom_fwd
from torch.autograd import Function

from spar3d.models.utils import BaseModule, normalize
from spar3d.utils import get_device


def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
    """Build a decorator that is only active when ``condition`` is truthy.

    Args:
        decorator_with_args: Decorator factory; it is called as
            ``decorator_with_args(*args, **kwargs)`` and the result is applied
            to the decorated function.
        condition: When falsy, the decorated function is returned untouched.
        *args: Positional arguments forwarded to ``decorator_with_args``.
        **kwargs: Keyword arguments forwarded to ``decorator_with_args``.

    Returns:
        A one-argument callable suitable for use with ``@`` syntax.
    """

    def wrapper(fn):
        # Disabled: hand the function back unchanged.
        if not condition:
            return fn
        if kwargs:
            return decorator_with_args(*args, **kwargs)(fn)
        # NOTE(review): without keyword arguments the factory object itself is
        # returned and ``fn`` is dropped (positional ``args`` are ignored too).
        # Callers that need ``fn`` wrapped must pass at least one keyword.
        return decorator_with_args

    return wrapper


class PixelShuffleUpsampleNetwork(BaseModule):
    """Convolutional stack followed by ``nn.PixelShuffle``, applied per plane.

    The three planes of the input are folded into the batch axis, pushed
    through the same conv stack, and unfolded again.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024  # channels of the incoming planes (and of hidden convs)
        out_channels: int = 40  # channels per plane after the pixel shuffle
        scale_factor: int = 4  # upscale factor handed to nn.PixelShuffle

        conv_layers: int = 4  # number of Conv2d layers in the stack
        conv_kernel_size: int = 3  # square kernel size shared by all convs

    cfg: Config

    def configure(self) -> None:
        """Assemble ``self.upsample`` from the configuration."""
        cfg = self.cfg
        n_convs = cfg.conv_layers
        # The last conv must emit out_channels * scale_factor**2 channels so
        # that the pixel shuffle leaves out_channels behind.
        pre_shuffle_channels = cfg.out_channels * cfg.scale_factor**2
        padding = (cfg.conv_kernel_size - 1) // 2

        modules = []
        for idx in range(n_convs):
            is_last = idx == n_convs - 1
            conv_out = pre_shuffle_channels if is_last else cfg.in_channels
            modules.append(
                nn.Conv2d(
                    cfg.in_channels,
                    conv_out,
                    cfg.conv_kernel_size,
                    padding=padding,
                )
            )
            # Every conv except the final one is followed by a ReLU.
            if not is_last:
                modules.append(nn.ReLU(inplace=True))
        modules.append(nn.PixelShuffle(cfg.scale_factor))

        self.upsample = nn.Sequential(*modules)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        """Upsample all three planes with the shared conv + pixel-shuffle stack."""
        # Treat each plane as an independent batch element.
        flat_planes = rearrange(
            triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3
        )
        upsampled = self.upsample(flat_planes)
        # Restore the separate plane axis.
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)


class _TruncExp(Function):  # pylint: disable=abstract-method
    """Exponential whose backward pass uses an input clamped from above.

    ``forward`` returns ``exp(x)`` unchanged; ``backward`` multiplies the
    incoming gradient by ``exp(clamp(x, max=15))`` so the gradient factor is
    bounded by ``exp(15)`` for large inputs.

    When ``get_device()`` reports a CUDA device, both passes are wrapped with
    the ``torch.amp`` autocast helpers (``custom_fwd`` with
    ``cast_inputs=torch.float32`` and the matching ``custom_bwd``).
    """

    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @conditional_decorator(
        custom_fwd,
        "cuda" in get_device(),
        cast_inputs=torch.float32,
        device_type="cuda",
    )
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # Keep the raw input; backward recomputes the exponential from it.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    # device_type must be given by keyword: torch.amp.custom_bwd requires it,
    # and conditional_decorator only wraps the function when keyword arguments
    # are supplied. Without it, ``backward`` was replaced by ``custom_bwd``
    # itself on CUDA and every backward pass raised a TypeError.
    @conditional_decorator(custom_bwd, "cuda" in get_device(), device_type="cuda")
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        # d/dx exp(x) = exp(x); clamp the exponent so the factor stays finite.
        return g * torch.exp(torch.clamp(x, max=15))


trunc_exp = _TruncExp.apply


def get_activation(name) -> Callable:
    """Resolve an activation name to a callable operating on tensors.

    Matching is case-insensitive. ``None`` and the names ``"none"``,
    ``"linear"`` and ``"identity"`` give the identity function. A fixed set of
    custom names is handled next; any other name is looked up as an attribute
    of ``torch.nn.functional``.

    Args:
        name: Activation name, or ``None`` for identity.

    Returns:
        The callable implementing the requested activation.

    Raises:
        ValueError: If the name is neither a custom activation nor an
            attribute of ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x
    name = name.lower()
    if name in ("none", "linear", "identity"):
        return lambda x: x
    if name == "trunc_exp":
        # Returned directly (not wrapped) so callers get the function itself.
        return trunc_exp

    custom_activations = {
        # Piecewise linear-to-sRGB style curve, clamped to [0, 1].
        "lin2srgb": lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0),
        "exp": lambda x: torch.exp(x),
        "shifted_exp": lambda x: torch.exp(x - 1.0),
        "shifted_trunc_exp": lambda x: trunc_exp(x - 1.0),
        "sigmoid": lambda x: torch.sigmoid(x),
        "tanh": lambda x: torch.tanh(x),
        "shifted_softplus": lambda x: F.softplus(x - 1.0),
        # Affine map taking [-1, 1] onto [0, 1].
        "scale_-11_01": lambda x: x * 0.5 + 0.5,
        "negative": lambda x: -x,
        # normalize() is a project helper; the first variant uses its default dim.
        "normalize_channel_last": lambda x: normalize(x),
        "normalize_channel_first": lambda x: normalize(x, dim=1),
    }
    if name in custom_activations:
        return custom_activations[name]

    # Fall back to whatever torch.nn.functional exposes under this name.
    try:
        return getattr(F, name)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {name}")


class LambdaModule(torch.nn.Module):
    """Adapter exposing a plain tensor function as a ``torch.nn.Module``."""

    def __init__(self, lambd: Callable[[torch.Tensor], torch.Tensor]):
        """Store ``lambd``; it is invoked on every forward call."""
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return ``lambd(x)``."""
        fn = self.lambd
        return fn(x)


def get_activation_module(name) -> torch.nn.Module:
    """Return the activation selected by ``name`` wrapped as an ``nn.Module``.

    The name is resolved with ``get_activation`` and the resulting callable is
    wrapped in a ``LambdaModule``.
    """
    activation_fn = get_activation(name)
    return LambdaModule(activation_fn)


@dataclass
class HeadSpec:
    """Description of one output head of ``MaterialMLP``."""

    # Key under which the head is stored and under which its output is returned.
    name: str
    # Number of output features of the head's final linear layer.
    out_channels: int
    # Number of (linear, activation) pairs placed before the final linear layer.
    n_hidden_layers: int
    # Activation name resolved with get_activation(); None means identity.
    output_activation: Optional[str] = None
    # Constant added to the head's raw output before the output activation.
    out_bias: float = 0.0


class MaterialMLP(BaseModule):
    """Multi-head MLP with one independent stack of linear layers per head.

    Each head described by a ``HeadSpec`` gets its own ``nn.Sequential`` made
    of ``n_hidden_layers`` (linear, activation) pairs and a final linear layer
    with ``out_channels`` outputs. All heads read the same input tensor.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 120  # feature width of the input tensor
        n_neurons: int = 64  # width of every hidden layer
        activation: str = "silu"  # hidden activation: "relu" or "silu"
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self) -> None:
        """Build ``self.heads``, a ``ModuleDict`` keyed by head name."""
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            # Track the width feeding the next layer so a head with zero
            # hidden layers maps in_channels straight to out_channels instead
            # of expecting an n_neurons-wide input it never receives.
            in_features = self.cfg.in_channels
            for _ in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(in_features, self.cfg.n_neurons),
                    self.make_activation(self.cfg.activation),
                ]
                in_features = self.cfg.n_neurons
            head_layers.append(nn.Linear(in_features, head.out_channels))
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Create the hidden-layer activation module.

        Args:
            activation: Either ``"relu"`` or ``"silu"``.

        Returns:
            A new in-place ``nn.ReLU`` or ``nn.SiLU`` instance.

        Raises:
            NotImplementedError: For any other activation name.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(
                f"Unsupported hidden activation: {activation!r}"
            )

    def keys(self):
        """Return the names of all configured heads."""
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Evaluate the selected heads on ``x``.

        Args:
            x: Input tensor passed unchanged to every selected head.
            include: If given, only heads whose name is in this collection run.
            exclude: If given, heads whose name is in this collection are skipped.

        Returns:
            Dict mapping head name to that head's output, i.e. the output
            activation applied to ``head(x) + out_bias``.

        Raises:
            ValueError: If both ``include`` and ``exclude`` are provided.
        """
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }

        return out