from .tensorBase import *
import torch.nn as nn
import itertools

class Density(nn.Module):
    """Base class for SDF-to-density mappings with learnable scalar parameters.

    Every entry of ``params_init`` is registered on the module as an
    ``nn.Parameter`` under its own key. This class does not define
    ``density_func``; subclasses are expected to provide
    ``density_func(sdf, beta=None)``, which ``forward`` delegates to.
    """

    def __init__(self, params_init=None):
        """Register each ``name -> initial value`` pair as a learnable parameter.

        Args:
            params_init: Optional mapping from attribute name to an initial
                value accepted by ``torch.tensor``. ``None`` (the default)
                registers nothing.
        """
        super().__init__()
        # ``None`` replaces the former shared ``{}`` default (mutable-default pitfall).
        if params_init is None:
            params_init = {}
        for name in params_init:
            setattr(self, name, nn.Parameter(torch.tensor(params_init[name])))

        # NOTE: ``beta0`` / ``beta1`` are not initialised here. Values of 0.1 and
        # 0.001 (with ``beta`` starting at ``beta0``) were previously sketched
        # at this point but left disabled.

    def forward(self, sdf, beta=None):
        """Return ``self.density_func(sdf, beta=beta)``.

        Args:
            sdf: Signed-distance input, passed through unchanged.
            beta: Optional scale override, passed through unchanged.
        """
        return self.density_func(sdf, beta=beta)


class LaplaceDensity(Density):  # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
    """Map signed distances to density with a scaled Laplace CDF.

    ``density = (1 / beta) * Laplace(loc=0, scale=beta).cdf(-sdf)``

    ``beta`` is normally supplied as a learnable parameter through
    ``params_init`` (e.g. ``{'beta': 0.1}``). ``beta_min`` (e.g. ``0.0001``)
    is added to ``|beta|`` so the effective scale stays strictly positive.
    """

    # Fallback annealing endpoints used by ``set_beta`` when the instance has
    # no ``beta0`` / ``beta1`` attributes of its own.
    _DEFAULT_BETA0 = 0.1
    _DEFAULT_BETA1 = 0.001

    def __init__(self, params_init=None, beta_min=0.0001):
        """Create the density module.

        Args:
            params_init: Optional mapping of parameter name to initial value,
                forwarded to ``Density``. ``None`` means no parameters.
            beta_min: Positive floor added to ``|beta|`` in ``get_beta``.
        """
        super().__init__(params_init={} if params_init is None else params_init)
        # A non-persistent buffer follows ``.to()`` / ``.cuda()`` / ``.half()``
        # with the rest of the module, works on CPU-only hosts, and leaves
        # ``state_dict`` unchanged (the value was previously pinned to CUDA).
        self.register_buffer('beta_min', torch.tensor(beta_min), persistent=False)

    def density_func(self, sdf, beta=None):
        """Evaluate the density for ``sdf``.

        Args:
            sdf: Tensor of signed distances.
            beta: Optional scale; when ``None`` the value from ``get_beta()``
                is used. An explicit value is used as given (no floor added).

        Returns:
            Tensor with the same shape as ``sdf``.
        """
        if beta is None:
            beta = self.get_beta()

        alpha = 1 / beta
        # expm1(-|sdf| / beta) == exp(-|sdf| / beta) - 1, i.e. the Laplace CDF
        # evaluated at -sdf, written in a numerically friendlier form.
        return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))

    def get_beta(self):
        """Return the effective scale ``|beta| + beta_min``.

        The floored value is returned (not the raw ``self.beta``), so the
        scale can neither be zero nor negative.
        """
        # Built-in ``abs`` works for tensors/parameters and for plain numbers.
        return abs(self.beta) + self.beta_min

    # t for 0-1
    def set_beta(self, t):
        """Set ``beta`` from a schedule position ``t``.

        ``beta(t) = beta0 / (1 + ((beta0 - beta1) / beta1) * t ** 0.8)``,
        which equals ``beta0`` at ``t = 0`` and ``beta1`` at ``t = 1``.

        ``beta0`` / ``beta1`` are read from the instance when present and
        otherwise fall back to 0.1 / 0.001.

        Args:
            t: Schedule position, documented by the original comment as 0-1.

        Returns:
            ``self.beta`` after the update.
        """
        beta0 = getattr(self, 'beta0', self._DEFAULT_BETA0)
        beta1 = getattr(self, 'beta1', self._DEFAULT_BETA1)
        value = beta0 * (1 + ((beta0 - beta1) / beta1) * (t ** 0.8)) ** -1

        if isinstance(getattr(self, 'beta', None), nn.Parameter):
            # A registered parameter cannot be rebound to a plain value
            # (nn.Module raises TypeError), so overwrite its data in place.
            with torch.no_grad():
                self.beta.fill_(value)
        else:
            self.beta = value
        return self.beta

   
    
class TensorVMSplit_Mesh(TensorBase):
    def __init__(self, aabb, gridSize, **kargs):
        """Build the decoder, deformation and weight MLP heads.

        All arguments are forwarded unchanged to the parent constructor.
        ``self.density_n_comp`` and ``self.app_n_comp`` are read afterwards,
        so the parent is expected to have set them.
        """
        super().__init__(aabb, gridSize, **kargs)

        def make_mlp(in_features, width, out_features, depth):
            # Linear -> ReLU, then (depth - 2) hidden Linear -> ReLU pairs,
            # then one output Linear without activation.
            stack = [nn.Linear(in_features, width), nn.ReLU()]
            for _ in range(depth - 2):
                stack.append(nn.Linear(width, width))
                stack.append(nn.ReLU())
            stack.append(nn.Linear(width, out_features))
            return nn.Sequential(*stack)

        # Input width of a single sampled feature vector.
        feature_dim = (self.density_n_comp + self.app_n_comp) * 3

        # 7 output channels: get_texture_prediction / predict_color use all but
        # the last one, geometry_feature_decode uses the last one as SDF.
        self.decoder = make_mlp(feature_dim, 64, 7, depth=5)

        # The two small heads have depth 2, i.e. no extra hidden pair.
        self.net_deformation = make_mlp(feature_dim, 64, 3, depth=2)
        # Consumes the features of 8 points concatenated together.
        self.net_weight = make_mlp(feature_dim * 8, 64, 21, depth=2)

        # Zero every Linear bias in this module (including inherited submodules).
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.zeros_(layer.bias)
                
    def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC):
        """No-op: ignores every argument and returns ``None``."""
        return None
    
    
    def compute_densityfeature(self, xyz_sampled):
        """Sample the density planes/lines at the given points and project them.

        Reads ``self.density_plane``, ``self.density_line`` and
        ``self.d_basis_mat``, which this method does not set.

        Args:
            xyz_sampled: 3-D tensor unpacked as ``(B, N_point, coords)``.
                Presumably already normalised to the range ``grid_sample``
                expects - TODO confirm.

        Returns:
            ``matmul`` of the ``(B, N_point, C)`` plane*line products with
            ``self.d_basis_mat``.
        """
        batch, n_points, _ = xyz_sampled.shape

        # One coordinate pair per plane (matMode) and one scalar per line (vecMode).
        plane_grid = torch.stack(
            [xyz_sampled[..., self.matMode[k]] for k in range(3)]
        ).detach().view(3, batch, -1, 1, 2)
        line_pos = torch.stack(
            [xyz_sampled[..., self.vecMode[k]] for k in range(3)]
        )
        # Lines are sampled through the 2-D grid_sample API, so the other grid
        # coordinate is padded with zeros.
        line_grid = torch.stack(
            (torch.zeros_like(line_pos), line_pos), dim=-1
        ).detach().view(3, batch, -1, 1, 2)

        plane_feats = []
        line_feats = []
        for k in range(3):
            plane_k = self.density_plane[:, k]
            line_k = self.density_line[:, k]
            plane_feats.append(
                F.grid_sample(plane_k, plane_grid[k], align_corners=True).view(batch, -1, n_points))
            line_feats.append(
                F.grid_sample(line_k, line_grid[k], align_corners=True).view(batch, -1, n_points))

        # Concatenate the three orientations along channels, multiply plane by
        # line, then move channels last for the basis projection.
        combined = torch.cat(plane_feats, dim=1) * torch.cat(line_feats, dim=1)
        return torch.matmul(combined.permute(0, 2, 1), self.d_basis_mat)


    def compute_appfeature(self, xyz_sampled):
        """Sample the appearance planes/lines at the given points.

        Reads ``self.app_plane`` and ``self.app_line``, which this method does
        not set. ``self.basis_mat`` is not applied here.

        Args:
            xyz_sampled: 3-D tensor unpacked as ``(B, N_point, coords)``.
                Presumably already normalised to the range ``grid_sample``
                expects - TODO confirm.

        Returns:
            Tensor of shape ``(B, N_point, C)`` holding the element-wise
            plane*line products, with the three orientations concatenated
            along ``C``.
        """
        batch, n_points, _ = xyz_sampled.shape

        # One coordinate pair per plane (matMode) and one scalar per line (vecMode).
        plane_grid = torch.stack(
            [xyz_sampled[..., self.matMode[k]] for k in range(3)]
        ).detach().view(3, batch, -1, 1, 2)
        line_pos = torch.stack(
            [xyz_sampled[..., self.vecMode[k]] for k in range(3)]
        )
        # Lines are sampled through the 2-D grid_sample API, so the other grid
        # coordinate is padded with zeros.
        line_grid = torch.stack(
            (torch.zeros_like(line_pos), line_pos), dim=-1
        ).detach().view(3, batch, -1, 1, 2)

        plane_feats = []
        line_feats = []
        for k in range(3):
            plane_k = self.app_plane[:, k]
            line_k = self.app_line[:, k]
            plane_feats.append(
                F.grid_sample(plane_k, plane_grid[k], align_corners=True).view(batch, -1, n_points))
            line_feats.append(
                F.grid_sample(line_k, line_grid[k], align_corners=True).view(batch, -1, n_points))

        combined = torch.cat(plane_feats, dim=1) * torch.cat(line_feats, dim=1)
        # Channels last: (B, C, N_point) -> (B, N_point, C).
        return combined.permute(0, 2, 1)


    def geometry_feature_decode(self, sampled_features, flexicubes_indices):
        """Decode sampled features into SDF, deformation and per-cube weights.

        Args:
            sampled_features: Tensor whose dim 0 is the batch, dim 1 indexes
                points and last dim holds the feature channels.
            flexicubes_indices: Index tensor with at least two dims; its first
                two sizes are used as (cubes, points per cube). Presumably 8
                points per cube, matching the ``net_weight`` input width -
                verify against the caller.

        Returns:
            Tuple ``(sdf, deformation, weight)``: the last decoder channel
            (kept as a size-1 dim), the ``net_deformation`` output, and the
            ``net_weight`` output scaled by 0.1.
        """
        # The shared decoder's final channel is interpreted as the SDF value.
        decoded = self.decoder(sampled_features)
        sdf = decoded[..., -1:]
        deformation = self.net_deformation(sampled_features)

        n_batch = sampled_features.shape[0]
        n_channels = sampled_features.shape[-1]
        n_cubes = flexicubes_indices.shape[0]
        points_per_cube = flexicubes_indices.shape[1]

        # Gather the features of every point referenced by every cube, then
        # flatten each cube's points into a single feature vector.
        gathered = torch.index_select(sampled_features, 1, flexicubes_indices.reshape(-1))
        per_cube = gathered.reshape(n_batch, n_cubes, points_per_cube * n_channels)
        weight = self.net_weight(per_cube) * 0.1

        return sdf, deformation, weight


    def get_geometry_prediction(self, svd_volume, sample_coordinates, flexicubes_indices):
        """Load a factorised volume and decode geometry at the given points.

        Stores the volume and its components on ``self``, so later calls such
        as ``get_texture_prediction`` reuse them.

        Args:
            svd_volume: Mapping with keys ``'app_planes'``, ``'app_lines'``,
                ``'basis_mat'``, ``'density_planes'``, ``'density_lines'``
                and ``'d_basis_mat'``.
            sample_coordinates: Points passed as-is to ``compute_appfeature``.
            flexicubes_indices: Passed as-is to ``geometry_feature_decode``.

        Returns:
            Tuple ``(sdf, deformation, weight)``.
        """
        self.svd_volume = svd_volume

        app_planes = svd_volume['app_planes']
        app_lines = svd_volume['app_lines']
        self.basis_mat = svd_volume['basis_mat']
        self.density_plane = svd_volume['density_planes']
        self.density_line = svd_volume['density_lines']
        self.d_basis_mat = svd_volume['d_basis_mat']

        # The "appearance" tensors actually sampled are appearance and density
        # concatenated along dim 2.
        self.app_plane = torch.cat([app_planes, self.density_plane], 2)
        self.app_line = torch.cat([app_lines, self.density_line], 2)

        features = self.compute_appfeature(sample_coordinates)
        sdf, deformation, weight = self.geometry_feature_decode(features, flexicubes_indices)
        return sdf, deformation, weight
        
    def get_texture_prediction(self, texture_pos, vsd_vome=None):
        """Predict texture values at ``texture_pos``.

        Uses whatever ``self.app_plane`` / ``self.app_line`` currently hold
        (set by e.g. ``get_geometry_prediction`` or ``predict_color``); this
        method does not load a volume itself.

        Args:
            texture_pos: Points passed as-is to ``compute_appfeature``.
            vsd_vome: Unused. Presumably a misspelling of ``svd_volume``;
                kept for interface compatibility.

        Returns:
            All decoder channels except the last, passed through a sigmoid
            rescaled to the open interval (-0.001, 1.001). Not clamped.
        """
        features = self.compute_appfeature(texture_pos)

        # Every decoder channel but the final one carries texture.
        raw = self.decoder(features)[..., 0:-1]

        pad = 0.001
        return torch.sigmoid(raw) * (1 + 2 * pad) - pad
    
    
    
    def predict_color(self, svd_volume, xyz_sampled, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
        """Load a factorised volume and predict shading/albedo at query points.

        Stores the volume and its components on ``self`` as a side effect.

        Args:
            svd_volume: Mapping with keys ``'app_planes'``, ``'app_lines'``,
                ``'basis_mat'``, ``'d_basis_mat'``, ``'density_planes'`` and
                ``'density_lines'``.
            xyz_sampled: Query points with the point axis at dim 1 (the
                feature sampler unpacks a 3-D ``(B, N_point, coords)`` shape).
                Normalised here with ``self.normalize_coord``.
            white_bg, is_train, ndc_ray, N_samples: Unused; kept for
                interface compatibility.

        Returns:
            Dict with ``'shading'`` (channels 0:3), ``'albedo'`` (channels
            3:6) and ``'rgb'`` (their element-wise product).
        """
        self.svd_volume = svd_volume
        self.app_plane = svd_volume['app_planes']
        self.app_line = svd_volume['app_lines']
        self.basis_mat = svd_volume['basis_mat']
        self.d_basis_mat = svd_volume['d_basis_mat']
        self.density_plane = svd_volume['density_planes']
        self.density_line = svd_volume['density_lines']

        # Appearance and density tensors are sampled together (concatenated on dim 2).
        self.app_plane = torch.cat([self.app_plane, self.density_plane], 2)
        self.app_line = torch.cat([self.app_line, self.density_line], 2)

        chunk_size: int = 2**20
        outs = []
        # Iterate over the point axis (dim 1), the same axis that is sliced
        # below. Iterating over dim 2 (the coordinate axis) ran a single pass
        # and silently dropped every point beyond the first chunk.
        for start in range(0, xyz_sampled.shape[1], chunk_size):
            xyz_chunk = self.normalize_coord(xyz_sampled[:, start:start + chunk_size])

            app_features = self.compute_appfeature(xyz_chunk)

            # All decoder channels except the last one are colour.
            chunk_out = self.decoder(app_features)[..., 0:-1]
            chunk_out = torch.sigmoid(chunk_out) * (1 + 2 * 0.001) - 0.001

            # NOTE(review): values are returned unclamped and may lie slightly
            # outside [0, 1]. A clamped copy used to be computed here and then
            # discarded; confirm whether clamping is wanted.
            outs.append(chunk_out)

        rgbs = torch.cat(outs, 1)

        albedo = rgbs[:, :, 3:6]
        rgb = rgbs[:, :, 0:3]

        results = {
            'shading': rgb,
            'albedo': albedo,
            'rgb': rgb * albedo,
        }
        return results
        



# NeRF variant for the mesh pipeline: volume density is derived from a predicted SDF.
class TensorVMSplit_NeRF(TensorBase):
    def __init__(self, aabb, gridSize, **kargs):
        """Build the Laplace density module plus the SDF and colour MLPs.

        All arguments are forwarded unchanged to the parent constructor.
        ``self.density_n_comp`` and ``self.app_n_comp`` are read afterwards,
        so the parent is expected to have set them.
        """
        super().__init__(aabb, gridSize, **kargs)

        def make_mlp(in_features, width, out_features, depth):
            # Linear -> ReLU, then (depth - 2) hidden Linear -> ReLU pairs,
            # then one output Linear without activation.
            stack = [nn.Linear(in_features, width), nn.ReLU()]
            for _ in range(depth - 2):
                stack.append(nn.Linear(width, width))
                stack.append(nn.ReLU())
            stack.append(nn.Linear(width, out_features))
            return nn.Sequential(*stack)

        # Converts the predicted SDF into volume density; beta starts at 0.1.
        self.lap_density = LaplaceDensity(params_init={'beta': 0.1})

        # Input width of a single sampled feature vector.
        feature_dim = (self.density_n_comp + self.app_n_comp) * 3

        # One SDF value per sample.
        self.net_sdf = make_mlp(feature_dim, 64, 1, depth=4)
        # Six colour channels per sample (split as 0:3 and 3:6 by the callers).
        self.decoder = make_mlp(feature_dim, 64, 6, depth=4)

        # Zero every Linear bias in this module (including inherited submodules).
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.zeros_(layer.bias)
                
    def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC):
        """No-op: ignores every argument and returns ``None``."""
        return None
    
    
    def compute_densityfeature(self, xyz_sampled):
        """Sample the density planes/lines along rays and project them.

        Reads ``self.density_plane``, ``self.density_line`` and
        ``self.d_basis_mat``, which this method does not set.

        Args:
            xyz_sampled: 4-D tensor unpacked as
                ``(B, N_pixel, N_sample, coords)``. Presumably already
                normalised to the range ``grid_sample`` expects - TODO confirm.

        Returns:
            ``matmul`` of the ``(B, N_pixel, N_sample, C)`` plane*line
            products with ``self.d_basis_mat.unsqueeze(1)``.
        """
        batch, n_pixels, n_samples, _ = xyz_sampled.shape

        # One coordinate pair per plane (matMode) and one scalar per line (vecMode).
        plane_grid = torch.stack(
            [xyz_sampled[..., self.matMode[k]] for k in range(3)]
        ).detach().view(3, batch, -1, 1, 2)
        line_pos = torch.stack(
            [xyz_sampled[..., self.vecMode[k]] for k in range(3)]
        )
        # Lines are sampled through the 2-D grid_sample API, so the other grid
        # coordinate is padded with zeros.
        line_grid = torch.stack(
            (torch.zeros_like(line_pos), line_pos), dim=-1
        ).detach().view(3, batch, -1, 1, 2)

        out_shape = (batch, -1, n_pixels, n_samples)
        plane_feats = []
        line_feats = []
        for k in range(3):
            plane_k = self.density_plane[:, k]
            line_k = self.density_line[:, k]
            plane_feats.append(
                F.grid_sample(plane_k, plane_grid[k], align_corners=True).view(*out_shape))
            line_feats.append(
                F.grid_sample(line_k, line_grid[k], align_corners=True).view(*out_shape))

        combined = torch.cat(plane_feats, dim=1) * torch.cat(line_feats, dim=1)
        # Channels last, then project with the density basis (broadcast over pixels).
        return torch.matmul(combined.permute(0, 2, 3, 1), self.d_basis_mat.unsqueeze(1))


    def compute_appfeature(self, xyz_sampled):
        """Sample the appearance planes/lines along rays.

        Reads ``self.app_plane`` and ``self.app_line``, which this method does
        not set.

        Args:
            xyz_sampled: 4-D tensor unpacked as
                ``(B, N_pixel, N_sample, coords)``. Presumably already
                normalised to the range ``grid_sample`` expects - TODO confirm.

        Returns:
            Tensor of shape ``(B, N_pixel, N_sample, C)`` holding the
            element-wise plane*line products, with the three orientations
            concatenated along ``C``.
        """
        batch, n_pixels, n_samples, _ = xyz_sampled.shape

        # One coordinate pair per plane (matMode) and one scalar per line (vecMode).
        plane_grid = torch.stack(
            [xyz_sampled[..., self.matMode[k]] for k in range(3)]
        ).detach().view(3, batch, -1, 1, 2)
        line_pos = torch.stack(
            [xyz_sampled[..., self.vecMode[k]] for k in range(3)]
        )
        # Lines are sampled through the 2-D grid_sample API, so the other grid
        # coordinate is padded with zeros.
        line_grid = torch.stack(
            (torch.zeros_like(line_pos), line_pos), dim=-1
        ).detach().view(3, batch, -1, 1, 2)

        out_shape = (batch, -1, n_pixels, n_samples)
        plane_feats = []
        line_feats = []
        for k in range(3):
            plane_k = self.app_plane[:, k]
            line_k = self.app_line[:, k]
            plane_feats.append(
                F.grid_sample(plane_k, plane_grid[k], align_corners=True).view(*out_shape))
            line_feats.append(
                F.grid_sample(line_k, line_grid[k], align_corners=True).view(*out_shape))

        combined = torch.cat(plane_feats, dim=1) * torch.cat(line_feats, dim=1)
        # Channels last: (B, C, N_pixel, N_sample) -> (B, N_pixel, N_sample, C).
        return combined.permute(0, 2, 3, 1)
    
    def forward(self, svd_volume, rays_o, rays_d, bg_color, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
        """Volume-render colour, albedo, opacity and depth for a batch of rays.

        Stores the volume and its components on ``self`` as a side effect.

        Args:
            svd_volume: Mapping with keys ``'app_planes'``, ``'app_lines'``,
                ``'basis_mat'``, ``'d_basis_mat'``, ``'density_planes'`` and
                ``'density_lines'``.
            rays_o: Ray origins, unpacked as ``(B, V, H, W, _)`` and reshaped
                to ``(B, -1, 3)``.
            rays_d: Ray directions, reshaped to ``(B, -1, 3)``.
            bg_color: Currently unused - the only background composited here
                is white (see ``white_bg``).
            white_bg: When true, add ``1 - accumulated_opacity`` to every
                rendered channel.
            is_train: Forwarded to ``self.sample_ray``; when ``white_bg`` is
                false it also enables the white background at random with
                probability 0.5.
            ndc_ray: Must be false; NDC sampling is not implemented.
            N_samples: Forwarded to ``self.sample_ray``.

        Returns:
            Dict with ``'image'`` and ``'albedo'`` of shape ``(B, V, 3, H, W)``
            and ``'alpha'`` and ``'depth_map'`` of shape ``(B, V, 1, H, W)``.

        Raises:
            NotImplementedError: If ``ndc_ray`` is true.
        """
        if ndc_ray:
            # This branch never produced samples, so the code further down
            # failed on unbound locals. Fail early with a clear message instead.
            raise NotImplementedError(
                'ndc_ray=True is not supported by TensorVMSplit_NeRF.forward')

        self.svd_volume = svd_volume
        self.app_plane = svd_volume['app_planes']
        self.app_line = svd_volume['app_lines']
        self.basis_mat = svd_volume['basis_mat']
        self.d_basis_mat = svd_volume['d_basis_mat']
        self.density_plane = svd_volume['density_planes']
        self.density_line = svd_volume['density_lines']

        # Appearance and density tensors are sampled together (concatenated on dim 2).
        self.app_plane = torch.cat([self.app_plane, self.density_plane], 2)
        self.app_line = torch.cat([self.app_line, self.density_line], 2)

        B, V, H, W, _ = rays_o.shape
        rays_o = rays_o.reshape(B, -1, 3)
        rays_d = rays_d.reshape(B, -1, 3)

        # The third return value of sample_ray is not needed here.
        xyz_sampled, z_vals, _ = self.sample_ray(rays_o, rays_d, is_train=is_train, N_samples=N_samples)
        # Distance between consecutive samples; the last sample gets distance 0.
        dists = torch.cat((z_vals[..., 1:] - z_vals[..., :-1], torch.zeros_like(z_vals[..., :1])), dim=-1)

        xyz_sampled = self.normalize_coord(xyz_sampled)

        mix_feature = self.compute_appfeature(xyz_sampled)

        # SDF -> density through the Laplace CDF, then drop the size-1 channel dim.
        sdf = self.net_sdf(mix_feature)
        sigma = self.lap_density(sdf)[..., 0]
        _, weight, _ = raw2alpha(sigma, dists)

        # Colour does not depend on the view direction: only features are decoded.
        rgbs = self.decoder(mix_feature)
        rgbs = torch.sigmoid(rgbs) * (1 + 2 * 0.001) - 0.001

        acc_map = torch.sum(weight, -1)
        rgb_map = torch.sum(weight[..., None] * rgbs, -2)

        if white_bg or (is_train and torch.rand((1,)) < 0.5):
            rgb_map = rgb_map + (1. - acc_map[..., None])

        rgb_map = rgb_map.clamp(0, 1)
        rgb_map = rgb_map.view(B, V, H, W, 6).permute(0, 1, 4, 2, 3)

        # The six rendered channels are split into image (0:3) and albedo (3:6).
        albedo_map = rgb_map[:, :, 3:6, :, :]
        rgb_map = rgb_map[:, :, 0:3, :, :]

        # Depth is computed without gradient tracking.
        with torch.no_grad():
            depth_map = torch.sum(weight * z_vals, -1)
        depth_map = depth_map.view(B, V, H, W, 1).permute(0, 1, 4, 2, 3)
        acc_map = acc_map.view(B, V, H, W, 1).permute(0, 1, 4, 2, 3)

        results = {
            'image': rgb_map,
            'albedo': albedo_map,
            'alpha': acc_map,
            'depth_map': depth_map
        }

        return results
    
    
    def predict_sdf(self, svd_volume, xyz_sampled, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
        """Load a factorised volume and evaluate the SDF network at query points.

        Stores the volume and its components on ``self`` as a side effect.

        Args:
            svd_volume: Mapping with keys ``'app_planes'``, ``'app_lines'``,
                ``'basis_mat'``, ``'d_basis_mat'``, ``'density_planes'`` and
                ``'density_lines'``.
            xyz_sampled: Query points, processed in chunks of 2**20 along
                dim 1. Each chunk is normalised with ``self.normalize_coord``
                and cast to half precision. NOTE(review): the feature sampler
                unpacks a 4-D shape, so a 4-D input is presumably expected -
                confirm against the caller.
            white_bg, is_train, ndc_ray, N_samples: Unused; kept for
                interface compatibility.

        Returns:
            Dict with the single key ``'sigma'`` holding the raw ``net_sdf``
            output (concatenated along dim 1).
        """
        self.svd_volume = svd_volume

        app_planes = svd_volume['app_planes']
        app_lines = svd_volume['app_lines']
        self.basis_mat = svd_volume['basis_mat']
        self.d_basis_mat = svd_volume['d_basis_mat']
        self.density_plane = svd_volume['density_planes']
        self.density_line = svd_volume['density_lines']

        # Appearance and density tensors are sampled together (concatenated on dim 2).
        self.app_plane = torch.cat([app_planes, self.density_plane], 2)
        self.app_line = torch.cat([app_lines, self.density_line], 2)

        step = 2 ** 20
        pieces = []
        for lo in range(0, xyz_sampled.shape[1], step):
            coords = self.normalize_coord(xyz_sampled[:, lo:lo + step]).half()
            features = self.compute_appfeature(coords)
            pieces.append(self.net_sdf(features))

        return {'sigma': torch.cat(pieces, 1)}
    
    
    def predict_color(self, svd_volume, xyz_sampled, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
        """Load a factorised volume and predict shading/albedo at query points.

        Stores the volume and its components on ``self`` as a side effect.

        Args:
            svd_volume: Mapping with keys ``'app_planes'``, ``'app_lines'``,
                ``'basis_mat'``, ``'d_basis_mat'``, ``'density_planes'`` and
                ``'density_lines'``.
            xyz_sampled: Query points with the point axis at dim 1. A
                singleton axis is inserted at dim 2 before sampling and
                removed again afterwards, so a ``(B, N, coords)`` input is
                presumably expected - confirm against the caller. Each chunk
                is normalised with ``self.normalize_coord`` and cast to half
                precision.
            white_bg, is_train, ndc_ray, N_samples: Unused; kept for
                interface compatibility.

        Returns:
            Dict with ``'shading'`` (channels 0:3), ``'albedo'`` (channels
            3:6) and ``'rgb'`` (their element-wise product).
        """
        self.svd_volume = svd_volume
        self.app_plane = svd_volume['app_planes']
        self.app_line = svd_volume['app_lines']
        self.basis_mat = svd_volume['basis_mat']
        self.d_basis_mat = svd_volume['d_basis_mat']
        self.density_plane = svd_volume['density_planes']
        self.density_line = svd_volume['density_lines']

        # Appearance and density tensors are sampled together (concatenated on dim 2).
        self.app_plane = torch.cat([self.app_plane, self.density_plane], 2)
        self.app_line = torch.cat([self.app_line, self.density_line], 2)

        # The feature sampler unpacks a 4-D shape, so add a singleton axis at dim 2.
        xyz_sampled = xyz_sampled.unsqueeze(2)

        chunk_size: int = 2**20
        outs = []
        # Iterate over the point axis (dim 1), the same axis that is sliced
        # below. Iterating over dim 2 (size 1 after the unsqueeze) ran a
        # single pass and silently dropped every point beyond the first chunk.
        for start in range(0, xyz_sampled.shape[1], chunk_size):
            xyz_chunk = self.normalize_coord(xyz_sampled[:, start:start + chunk_size]).half()

            app_features = self.compute_appfeature(xyz_chunk)

            chunk_out = self.decoder(app_features)
            chunk_out = torch.sigmoid(chunk_out) * (1 + 2 * 0.001) - 0.001

            # NOTE(review): values are returned unclamped and may lie slightly
            # outside [0, 1]. A clamped copy used to be computed here and then
            # discarded; confirm whether clamping is wanted.
            outs.append(chunk_out)

        rgbs = torch.cat(outs, 1)
        # Drop the singleton axis inserted above.
        rgbs = rgbs[:, :, 0, :]

        albedo = rgbs[:, :, 3:6]
        rgb = rgbs[:, :, 0:3]

        results = {
            'shading': rgb,
            'albedo': albedo,
            'rgb': rgb * albedo,
        }
        return results