import time
import torch
import trimesh
import numpy as np
import torch.optim as optim
from torch import autograd
from torch.utils.data import TensorDataset, DataLoader

from .common import make_3d_grid, transform_pointcloud
from .utils import libmcubes
from .utils.libmise import MISE
from .utils.libsimplify import simplify_mesh


class Generator3D(object):
    ''' Generator class for DVR.

    It provides functions to generate the final mesh as well as refinement
    options.

    Args:
        model (nn.Module): trained DVR model
        points_batch_size (int): batch size for points evaluation
        threshold (float): occupancy threshold for the surface decision
            boundary
        refinement_step (int): number of refinement steps
        device (device): PyTorch device
        resolution0 (int): start resolution for MISE
        upsampling_steps (int): number of upsampling steps
        with_normals (bool): whether normals should be estimated
        padding (float): how much padding should be used for MISE
        simplify_nfaces (int): number of faces the mesh should be simplified to
        with_color (bool): whether vertex colors should be estimated
        refine_max_faces (int): maximum number of faces used as the batch
            size in the refinement process (we added this functionality in
            this work)
    '''
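    # A minimal usage sketch (`model`, `batch`, and `device` are assumed to
    # come from the surrounding training/eval code; `batch` is a data dict
    # with an 'inputs' tensor, as consumed by generate_mesh):
    #
    #   generator = Generator3D(model, device=device, resolution0=32,
    #                           upsampling_steps=2, threshold=0.5)
    #   mesh, stats_dict = generator.generate_mesh(batch)
    #   mesh.export('mesh.off')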
    def __init__(
        self,
        model,
        points_batch_size=100000,
        threshold=0.5,
        refinement_step=0,
        device=None,
        resolution0=16,
        upsampling_steps=3,
        with_normals=False,
        padding=0.1,
        simplify_nfaces=None,
        with_color=False,
        refine_max_faces=10000
    ):
        self.model = model.to(device)
        self.points_batch_size = points_batch_size
        self.refinement_step = refinement_step
        self.threshold = threshold
        self.device = device
        self.resolution0 = resolution0
        self.upsampling_steps = upsampling_steps
        self.with_normals = with_normals
        self.padding = padding
        self.simplify_nfaces = simplify_nfaces
        self.with_color = with_color
        self.refine_max_faces = refine_max_faces

    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh.

        Args:
            data (dict): data dictionary
            return_stats (bool): whether stats should be returned
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}

        inputs = data.get('inputs', torch.empty(1, 0)).to(device)
        kwargs = {}

        c = self.model.encode_inputs(inputs)
        mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data, **kwargs)

        if return_stats:
            return mesh, stats_dict
        return mesh

    def generate_meshes(self, data, return_stats=True):
        ''' Generates the output meshes for batched data (batch size >= 1).

        Args:
            data (dict): data dictionary
            return_stats (bool): whether stats should be returned
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}

        inputs = data.get('inputs', torch.empty(1, 1, 0)).to(device)

        meshes = []
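        # Encode each input separately and generate one mesh per batch item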
        for i in range(inputs.shape[0]):
            input_i = inputs[i].unsqueeze(0)
            c = self.model.encode_inputs(input_i)
            mesh = self.generate_from_latent(c, stats_dict=stats_dict)
            meshes.append(mesh)

        if return_stats:
            return meshes, stats_dict
        return meshes

    def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True):
        ''' Generates a point cloud from the mesh.

        Args:
            mesh (trimesh): mesh
            data (dict): data dictionary
            n_points (int): number of point cloud points
            scale_back (bool): whether to undo scaling (requires a scale
                matrix in data dictionary)
        '''
        pcl = mesh.sample(n_points).astype(np.float32)

        if scale_back:
            scale_mat = data.get('camera.scale_mat_0', None) if data is not None else None
            if scale_mat is not None:
                pcl = transform_pointcloud(pcl, scale_mat[0])
            else:
                print('Warning: no scale_mat found; returning unscaled point cloud.')
        pcl_out = trimesh.Trimesh(vertices=pcl, process=False)
        return pcl_out

    def generate_from_latent(self, c=None, pl=None, stats_dict=None, data=None, **kwargs):
        ''' Generates mesh from latent.

        Args:
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
            stats_dict (dict): stats dictionary
            data (dict): data dictionary (currently unused)
        '''
        if stats_dict is None:
            stats_dict = {}
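        # The decoder outputs logits, so convert the probability threshold
        # to logit space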
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)

        t0 = time.time()
        # Compute bounding box size
        box_size = 1 + self.padding

        # Shortcut
        if self.upsampling_steps == 0:
            nx = self.resolution0
            # Sample the full grid at once; use the same [-box_size, box_size]
            # range as the grid-to-world mapping in extract_mesh
            pointsf = box_size * make_3d_grid((-1.0, ) * 3, (1.0, ) * 3, (nx, ) * 3)
            values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
            value_grid = values.reshape(nx, nx, nx)
        else:
            mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold)
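            # MISE alternates between querying active grid points and
            # updating the octree with the evaluated values until no
            # active points remain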

            points = mesh_extractor.query()

            while points.shape[0] != 0:
                # Query points
                pointsf = torch.FloatTensor(points).to(self.device)
                # Normalize to bounding box
                pointsf = 2 * pointsf / mesh_extractor.resolution
                pointsf = box_size * (pointsf - 1.0)
                # Evaluate model and update
                values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()

                values = values.astype(np.float64)
                mesh_extractor.update(points, values)
                points = mesh_extractor.query()

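            # Assemble the evaluated values into a dense grid at the final
            # MISE resolution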
            value_grid = mesh_extractor.to_dense()

        # Extract mesh
        stats_dict['time (eval points)'] = time.time() - t0

        mesh = self.extract_mesh(value_grid, c, stats_dict=stats_dict)
        return mesh

    def eval_points(self, p, c=None, pl=None, **kwargs):
        ''' Evaluates the occupancy values for the points.

        Args:
            p (tensor): points
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
        '''
        p_split = torch.split(p, self.points_batch_size)
        occ_hats = []

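        # Evaluate the points in chunks of points_batch_size to bound
        # GPU memory usage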
        for pi in p_split:
            pi = pi.unsqueeze(0).to(self.device)
            with torch.no_grad():
                occ_hat = self.model.decode(pi, c, pl, **kwargs).logits

            occ_hats.append(occ_hat.squeeze(0).detach().cpu())

        occ_hat = torch.cat(occ_hats, dim=0)

        return occ_hat

    def extract_mesh(self, occ_hat, c=None, stats_dict=None):
        ''' Extracts the mesh from the predicted occupancy grid.

        Args:
            occ_hat (tensor): value grid of occupancies
            c (tensor): latent conditioned code c
            stats_dict (dict): stats dictionary
        '''
        if stats_dict is None:
            stats_dict = {}
        # Some shorthands
        n_x, n_y, n_z = occ_hat.shape
        box_size = 1 + self.padding
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)
        # Pad the grid with a large negative value so that marching cubes
        # closes the mesh at the boundary (keeps it watertight)
        t0 = time.time()
        occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
        vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold)
        stats_dict['time (marching cubes)'] = time.time() - t0
        # Strange behaviour in libmcubes: vertices are shifted by 0.5
        vertices -= 0.5
        # Undo padding
        vertices -= 1
        # Normalize to bounding box
        vertices /= np.array([n_x - 1, n_y - 1, n_z - 1])
        vertices *= 2
        vertices = box_size * (vertices - 1)
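        # Grid indices [0, n - 1] now map to world coordinates in
        # [-box_size, box_size], matching the query grid used in
        # generate_from_latent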

        # Estimate normals if needed
        if self.with_normals and vertices.shape[0] > 0:
            t0 = time.time()
            normals = self.estimate_normals(vertices, c)
            stats_dict['time (normals)'] = time.time() - t0
        else:
            normals = None
        # Create mesh
        mesh = trimesh.Trimesh(
            vertices,
            triangles,
            vertex_normals=normals,
            process=False
        )

        # Directly return if mesh is empty
        if vertices.shape[0] == 0:
            return mesh

        # TODO: normals are lost here
        if self.simplify_nfaces is not None:
            t0 = time.time()
            mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.)
            stats_dict['time (simplify)'] = time.time() - t0

        # Refine mesh
        if self.refinement_step > 0:
            t0 = time.time()
            mesh = self.refine_mesh(mesh, occ_hat, c)
            stats_dict['time (refine)'] = time.time() - t0

        # Estimate Vertex Colors
        if self.with_color:  # empty meshes were returned early above
            t0 = time.time()
            vertex_colors = self.estimate_colors(np.array(mesh.vertices), c)
            stats_dict['time (color)'] = time.time() - t0
            mesh = trimesh.Trimesh(
                vertices=mesh.vertices,
                faces=mesh.faces,
                vertex_normals=mesh.vertex_normals,
                vertex_colors=vertex_colors,
                process=False
            )

        return mesh

    def estimate_colors(self, vertices, c=None):
        ''' Estimates vertex colors by evaluating the texture field.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)
        vertices_split = torch.split(vertices, self.points_batch_size)
        colors = []
        for vi in vertices_split:
            vi = vi.to(device)
            with torch.no_grad():
                ci = self.model.decode_color(vi.unsqueeze(0), c).squeeze(0).cpu().numpy()
            colors.append(ci)

        colors = np.concatenate(colors, axis=0)
        colors = np.clip(colors, 0, 1)
        colors = (colors * 255).astype(np.uint8)
        colors = np.concatenate(
            [colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], axis=1
        )
        return colors

    def estimate_normals(self, vertices, c=None):
        ''' Estimates the normals by computing the gradient of the occupancy
        logits with respect to the vertex positions.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)
        vertices_split = torch.split(vertices, self.points_batch_size)

        normals = []
        c = c.unsqueeze(0)
        for vi in vertices_split:
            vi = vi.unsqueeze(0).to(device)
            vi.requires_grad_()
            occ_hat = self.model.decode(vi, c).logits
            out = occ_hat.sum()
            out.backward()
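            # The outward normal is the negative, normalized gradient of the
            # occupancy logits with respect to the vertex position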
            ni = -vi.grad
            ni = ni / torch.norm(ni, dim=-1, keepdim=True)
            ni = ni.squeeze(0).cpu().numpy()
            normals.append(ni)

        normals = np.concatenate(normals, axis=0)
        return normals

    def refine_mesh(self, mesh, occ_hat, c=None):
        ''' Refines the predicted mesh.

        Args:
            mesh (trimesh object): predicted mesh
            occ_hat (tensor): predicted occupancy grid
            c (tensor): latent conditioned code c
        '''

        self.model.eval()

        # Some shorthands
        n_x, n_y, n_z = occ_hat.shape
        assert (n_x == n_y == n_z)
        # threshold = np.log(self.threshold) - np.log(1. - self.threshold)
        threshold = self.threshold

        # Vertex parameter
        v0 = torch.FloatTensor(mesh.vertices).to(self.device)
        v = torch.nn.Parameter(v0.clone())

        # Faces of mesh
        faces = torch.LongTensor(mesh.faces)

        # Detach c so that the encoding graph does not have to be retained
        # across optimization steps (required by newer PyTorch versions)
        c = c.detach()

        # Start optimization
        optimizer = optim.RMSprop([v], lr=1e-5)

        # Dataset
        ds_faces = TensorDataset(faces)
        dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, shuffle=True)

        # We updated the refinement algorithm to subsample faces; this is
        # useful when using a high extraction resolution / when working on
        # small GPUs
        it_r = 0
        while it_r < self.refinement_step:
            for f_it in dataloader:
                f_it = f_it[0].to(self.device)
                optimizer.zero_grad()

                # Loss
                face_vertex = v[f_it]
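                # Sample one random point per face via Dirichlet-distributed
                # barycentric coordinates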
                eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0])
                eps = torch.FloatTensor(eps).to(self.device)
                face_point = (face_vertex * eps[:, :, None]).sum(dim=1)

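                # Face normal from the cross product of two edge vectors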
                face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :]
                face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :]
                face_normal = torch.cross(face_v1, face_v2, dim=1)
                face_normal = face_normal / \
                    (face_normal.norm(dim=1, keepdim=True) + 1e-10)

                face_value = torch.cat(
                    [
                        torch.sigmoid(self.model.decode(p_split, c).logits)
                        for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1)
                    ],
                    dim=1
                )

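                # Target normal: negative gradient of the predicted occupancy
                # at the sampled face points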
                normal_target = -autograd.grad([face_value.sum()], [face_point],
                                               create_graph=True)[0]

                normal_target = \
                    normal_target / \
                    (normal_target.norm(dim=1, keepdim=True) + 1e-10)
                loss_target = (face_value - threshold).pow(2).mean()
                loss_normal = \
                    (face_normal - normal_target).pow(2).sum(dim=1).mean()

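                # Pull the sampled points onto the decision boundary and align
                # the face normals with the implicit surface normals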
                loss = loss_target + 0.01 * loss_normal

                # Update
                loss.backward()
                optimizer.step()

                # Update it_r
                it_r += 1

                if it_r >= self.refinement_step:
                    break

        mesh.vertices = v.data.cpu().numpy()
        return mesh