import torch
import torch.nn as nn

class Image2DResBlockWithTV(nn.Module):
    """2D residual block conditioned on two extra inputs ``t`` and ``v``.

    ``t`` and ``v`` are each projected to ``dim`` channels by a 1x1
    convolution and added to ``x``; the sum goes through two
    GroupNorm -> SiLU -> 3x3 conv stages and is added back onto ``x``.

    Args:
        dim: channel count of ``x`` (input and output). GroupNorm uses
            8 groups, so ``dim`` must be divisible by 8.
        tdim: channel count of ``t``.
        vdim: channel count of ``v``.
    """

    def __init__(self, dim, tdim, vdim):
        super().__init__()
        self.time_embed = nn.Conv2d(tdim, dim, kernel_size=1, stride=1)
        self.view_embed = nn.Conv2d(vdim, dim, kernel_size=1, stride=1)

        # Two identical (norm, activation, conv) stages, flattened into one
        # Sequential so the layer indices are 0..5.
        stages = []
        for _ in range(2):
            stages.append(nn.GroupNorm(num_groups=8, num_channels=dim))
            stages.append(nn.SiLU(inplace=True))
            stages.append(nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1))
        self.conv = nn.Sequential(*stages)

    def forward(self, x, t, v):
        """Return ``x`` plus the conv stack applied to the conditioned ``x``.

        ``t`` and ``v`` must be valid Conv2d inputs whose projections can be
        added to ``x`` (e.g. with 1x1 spatial size so they broadcast).
        """
        conditioned = x + self.time_embed(t) + self.view_embed(v)
        residual = self.conv(conditioned)
        return x + residual


class NoisyTargetViewEncoder(nn.Module):
    """Encode a 4-channel 2D input, conditioned on vectors ``t`` and ``v``.

    Pipeline: 3x3 conv (4 -> ``run_dim``), three ``Image2DResBlockWithTV``
    blocks, then GroupNorm -> SiLU -> 3x3 conv (``run_dim`` -> ``output_dim``).
    All convolutions use stride 1 and padding 1, so the spatial size is kept.

    Args:
        time_embed_dim: length of each row of ``t``.
        viewpoint_dim: length of each row of ``v``.
        run_dim: internal channel width (must be divisible by 8 for the
            8-group GroupNorm).
        output_dim: number of output channels.
    """

    def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8):
        super().__init__()

        self.init_conv = nn.Conv2d(4, run_dim, kernel_size=3, stride=1, padding=1)

        # Kept as three named attributes so parameter names remain
        # ``out_conv0.*``, ``out_conv1.*`` and ``out_conv2.*``.
        self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
        self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
        self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)

        self.final_out = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=run_dim),
            nn.SiLU(inplace=True),
            nn.Conv2d(run_dim, output_dim, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x, t, v):
        """Run the encoder.

        Args:
            x: input accepted by ``init_conv`` (4 channels).
            t: 2D tensor ``(B, time_embed_dim)``.
            v: 2D tensor ``(B, viewpoint_dim)``.

        Returns:
            Tensor with ``output_dim`` channels.
        """
        # Give the conditioning vectors 1x1 spatial dims so the 1x1 convs in
        # the residual blocks can consume them and the result broadcasts.
        batch, t_channels = t.shape
        t = t.view(batch, t_channels, 1, 1)
        batch, v_channels = v.shape
        v = v.view(batch, v_channels, 1, 1)

        feat = self.init_conv(x)
        for block in (self.out_conv0, self.out_conv1, self.out_conv2):
            feat = block(feat, t, v)
        return self.final_out(feat)

class SpatialUpTimeBlock(nn.Module):
    """3D upsampling block conditioned on ``t``.

    ``t`` is projected to ``x_in_dim`` channels with a 1x1x1 convolution and
    added to ``x``; the sum goes through GroupNorm (8 groups) -> SiLU -> a
    stride-2 transposed 3D convolution (kernel 3, padding 1, output padding 1),
    which doubles every spatial dimension.

    Args:
        x_in_dim: channel count of ``x`` (must be divisible by 8).
        t_in_dim: channel count of ``t``.
        out_dim: channel count of the output.
    """

    def __init__(self, x_in_dim, t_in_dim, out_dim):
        super().__init__()
        self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, kernel_size=1, stride=1)
        self.norm = nn.GroupNorm(num_groups=8, num_channels=x_in_dim)
        self.silu = nn.SiLU(inplace=True)
        self.conv = nn.ConvTranspose3d(
            x_in_dim,
            out_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def forward(self, x, t):
        """Add the projected ``t`` to ``x``, normalise, activate and upsample."""
        hidden = x + self.t_conv(t)
        hidden = self.norm(hidden)
        hidden = self.silu(hidden)
        return self.conv(hidden)

class SpatialTimeBlock(nn.Module):
    """3D convolution block conditioned on ``t``.

    ``t`` is projected to ``x_in_dim`` channels with a 1x1x1 convolution and
    added to ``x``; the sum goes through GroupNorm (8 groups) -> SiLU -> a
    3x3x3 convolution with padding 1 and the given stride.

    Args:
        x_in_dim: channel count of ``x`` (must be divisible by 8).
        t_in_dim: channel count of ``t``.
        out_dim: channel count of the output.
        stride: stride of the final 3D convolution.
    """

    def __init__(self, x_in_dim, t_in_dim, out_dim, stride):
        super().__init__()
        self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, kernel_size=1, stride=1)
        # Named ``bn`` although it is a GroupNorm; the name is part of the
        # parameter keys, so it is left untouched.
        self.bn = nn.GroupNorm(num_groups=8, num_channels=x_in_dim)
        self.silu = nn.SiLU(inplace=True)
        self.conv = nn.Conv3d(x_in_dim, out_dim, kernel_size=3, stride=stride, padding=1)

    def forward(self, x, t):
        """Add the projected ``t`` to ``x``, normalise, activate and convolve."""
        hidden = x + self.t_conv(t)
        hidden = self.bn(hidden)
        hidden = self.silu(hidden)
        return self.conv(hidden)

class SpatialTime3DNet(nn.Module):
    """3D encoder/decoder with additive skip connections, conditioned on ``t``.

    The encoder applies an initial 3x3x3 conv followed by ``SpatialTimeBlock``
    stages that downsample by stride 2 three times; the decoder applies three
    ``SpatialUpTimeBlock`` stages, each added to the encoder feature of the
    matching level. The output has ``dims[0]`` channels.

    Args:
        time_dim: length of each row of ``t``.
        input_dim: channel count of the input volume.
        dims: exactly four channel widths, one per resolution level.
    """

    def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
        super().__init__()
        c0, c1, c2, c3 = dims

        # Level 0 (input resolution).
        self.init_conv = nn.Conv3d(input_dim, c0, kernel_size=3, stride=1, padding=1)
        self.conv0 = SpatialTimeBlock(c0, time_dim, c0, stride=1)

        # Level 1 (stride-2 downsample, then two stride-1 blocks).
        self.conv1 = SpatialTimeBlock(c0, time_dim, c1, stride=2)
        self.conv2_0 = SpatialTimeBlock(c1, time_dim, c1, stride=1)
        self.conv2_1 = SpatialTimeBlock(c1, time_dim, c1, stride=1)

        # Level 2.
        self.conv3 = SpatialTimeBlock(c1, time_dim, c2, stride=2)
        self.conv4_0 = SpatialTimeBlock(c2, time_dim, c2, stride=1)
        self.conv4_1 = SpatialTimeBlock(c2, time_dim, c2, stride=1)

        # Level 3 (coarsest).
        self.conv5 = SpatialTimeBlock(c2, time_dim, c3, stride=2)
        self.conv6_0 = SpatialTimeBlock(c3, time_dim, c3, stride=1)
        self.conv6_1 = SpatialTimeBlock(c3, time_dim, c3, stride=1)

        # Decoder: one upsampling block per level, back up to level 0.
        self.conv7 = SpatialUpTimeBlock(c3, time_dim, c2)
        self.conv8 = SpatialUpTimeBlock(c2, time_dim, c1)
        self.conv9 = SpatialUpTimeBlock(c1, time_dim, c0)

    def forward(self, x, t):
        """Run the network.

        Args:
            x: input volume accepted by ``init_conv`` (``input_dim`` channels).
                NOTE(review): spatial sizes are expected to be divisible by 8
                so the upsampled tensors line up with the skips -- confirm
                against callers.
            t: 2D tensor ``(B, time_dim)``.

        Returns:
            Tensor with ``dims[0]`` channels.
        """
        # Append three singleton spatial dims so ``t`` is a valid Conv3d input.
        batch, channels = t.shape
        t = t.view(batch, channels, 1, 1, 1)

        # Encoder; ``skip*`` are kept for the decoder additions.
        feat = self.init_conv(x)
        skip0 = self.conv0(feat, t)

        feat = self.conv1(skip0, t)
        feat = self.conv2_0(feat, t)
        skip1 = self.conv2_1(feat, t)

        feat = self.conv3(skip1, t)
        feat = self.conv4_0(feat, t)
        skip2 = self.conv4_1(feat, t)

        feat = self.conv5(skip2, t)
        feat = self.conv6_0(feat, t)
        feat = self.conv6_1(feat, t)

        # Decoder: upsample and add the matching encoder feature.
        feat = skip2 + self.conv7(feat, t)
        feat = skip1 + self.conv8(feat, t)
        feat = skip0 + self.conv9(feat, t)
        return feat

class FrustumTVBlock(nn.Module):
    """3D convolution block conditioned on two inputs ``t`` and ``v``.

    ``t`` and ``v`` are each projected to ``x_dim`` channels with a 1x1x1
    convolution and added to ``x``; the sum goes through GroupNorm (8 groups)
    -> SiLU -> a 3x3x3 convolution with padding 1 and the given stride.

    Args:
        x_dim: channel count of ``x`` (must be divisible by 8).
        t_dim: channel count of ``t``.
        v_dim: channel count of ``v``.
        out_dim: channel count of the output.
        stride: stride of the final 3D convolution.
    """

    def __init__(self, x_dim, t_dim, v_dim, out_dim, stride):
        super().__init__()
        self.t_conv = nn.Conv3d(t_dim, x_dim, kernel_size=1, stride=1)
        self.v_conv = nn.Conv3d(v_dim, x_dim, kernel_size=1, stride=1)
        # Named ``bn`` although it is a GroupNorm; the name is part of the
        # parameter keys, so it is left untouched.
        self.bn = nn.GroupNorm(num_groups=8, num_channels=x_dim)
        self.silu = nn.SiLU(inplace=True)
        self.conv = nn.Conv3d(x_dim, out_dim, kernel_size=3, stride=stride, padding=1)

    def forward(self, x, t, v):
        """Add both projections to ``x``, then normalise, activate and convolve."""
        hidden = x + self.t_conv(t) + self.v_conv(v)
        hidden = self.bn(hidden)
        hidden = self.silu(hidden)
        return self.conv(hidden)

class FrustumTVUpBlock(nn.Module):
    """3D upsampling block conditioned on two inputs ``t`` and ``v``.

    ``t`` and ``v`` are each projected to ``x_dim`` channels with a 1x1x1
    convolution and added to ``x``; the sum goes through GroupNorm (8 groups)
    -> SiLU -> a stride-2 transposed 3D convolution (kernel 3, padding 1,
    output padding 1), which doubles every spatial dimension.

    Args:
        x_dim: channel count of ``x`` (must be divisible by 8).
        t_dim: channel count of ``t``.
        v_dim: channel count of ``v``.
        out_dim: channel count of the output.
    """

    def __init__(self, x_dim, t_dim, v_dim, out_dim):
        super().__init__()
        self.t_conv = nn.Conv3d(t_dim, x_dim, kernel_size=1, stride=1)
        self.v_conv = nn.Conv3d(v_dim, x_dim, kernel_size=1, stride=1)
        self.norm = nn.GroupNorm(num_groups=8, num_channels=x_dim)
        self.silu = nn.SiLU(inplace=True)
        self.conv = nn.ConvTranspose3d(
            x_dim,
            out_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def forward(self, x, t, v):
        """Add both projections to ``x``, then normalise, activate and upsample."""
        hidden = x + self.t_conv(t) + self.v_conv(v)
        hidden = self.norm(hidden)
        hidden = self.silu(hidden)
        return self.conv(hidden)

class FrustumTV3DNet(nn.Module):
    """3D encoder/decoder conditioned on ``t`` and ``v``, returning all scales.

    The encoder is a 3x3x3 conv followed by three (stride-2, stride-1) pairs
    of ``FrustumTVBlock``; the decoder is three ``FrustumTVUpBlock`` stages,
    each added to the encoder feature of the matching level. ``forward``
    returns the feature of every level in a dict.

    Args:
        in_dim: channel count of the input volume.
        t_dim: length of each row of ``t``.
        v_dim: length of each row of ``v``.
        dims: channel widths per level; indices 0..3 are used.
    """

    def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)):
        super().__init__()
        c0, c1, c2, c3 = dims[0], dims[1], dims[2], dims[3]

        self.conv0 = nn.Conv3d(in_dim, c0, kernel_size=3, stride=1, padding=1)

        # Encoder: each level is a stride-2 block followed by a stride-1 block.
        self.conv1 = FrustumTVBlock(c0, t_dim, v_dim, c1, 2)
        self.conv2 = FrustumTVBlock(c1, t_dim, v_dim, c1, 1)

        self.conv3 = FrustumTVBlock(c1, t_dim, v_dim, c2, 2)
        self.conv4 = FrustumTVBlock(c2, t_dim, v_dim, c2, 1)

        self.conv5 = FrustumTVBlock(c2, t_dim, v_dim, c3, 2)
        self.conv6 = FrustumTVBlock(c3, t_dim, v_dim, c3, 1)

        # Decoder: one upsampling block per level.
        self.up0 = FrustumTVUpBlock(c3, t_dim, v_dim, c2)
        self.up1 = FrustumTVUpBlock(c2, t_dim, v_dim, c1)
        self.up2 = FrustumTVUpBlock(c1, t_dim, v_dim, c0)

    def forward(self, x, t, v):
        """Run the network.

        Args:
            x: 5D input volume with ``in_dim`` channels.
            t: 2D tensor ``(B, t_dim)``.
            v: 2D tensor ``(B, v_dim)``.

        Returns:
            Dict keyed by the input's last-dimension size ``w`` and by
            ``w // 2``, ``w // 4``, ``w // 8``, mapping to the features of
            levels 0 (finest) to 3 (coarsest) respectively.
        """
        # Append three singleton spatial dims so both conditioning vectors
        # are valid Conv3d inputs.
        batch, t_channels = t.shape
        t = t.view(batch, t_channels, 1, 1, 1)
        batch, v_channels = v.shape
        v = v.view(batch, v_channels, 1, 1, 1)

        # Only the last dimension is used (for the dict keys); the full
        # unpack still requires ``x`` to be 5D.
        _, _, _, _, width = x.shape

        # Encoder.
        level0 = self.conv0(x)
        level1 = self.conv1(level0, t, v)
        level1 = self.conv2(level1, t, v)
        level2 = self.conv3(level1, t, v)
        level2 = self.conv4(level2, t, v)
        level3 = self.conv5(level2, t, v)
        level3 = self.conv6(level3, t, v)

        # Decoder: upsample and add the matching encoder feature.
        level2 = self.up0(level3, t, v) + level2
        level1 = self.up1(level2, t, v) + level1
        level0 = self.up2(level1, t, v) + level0

        return {
            width: level0,
            width // 2: level1,
            width // 4: level2,
            width // 8: level3,
        }