"""
Author: Yao Feng
Copyright (c) 2020, Yao Feng
All rights reserved.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import util


def set_rasterizer(type='pytorch3d'):
    """Select the rasterizer backend and bind its symbols at module level."""
    if type == 'pytorch3d':
        global Meshes, load_obj, rasterize_meshes
        from pytorch3d.structures import Meshes
        from pytorch3d.io import load_obj
        from pytorch3d.renderer.mesh import rasterize_meshes
    elif type == 'standard':
        global standard_rasterize, load_obj
        import os
        from .util import load_obj
        # Use JIT Compiling Extensions
        # ref: https://pytorch.org/tutorials/advanced/cpp_extension.html
        from torch.utils.cpp_extension import load, CUDA_HOME
        curr_dir = os.path.dirname(__file__)
        standard_rasterize_cuda = \
            load(name='standard_rasterize_cuda',
                 sources=[f'{curr_dir}/rasterizer/standard_rasterize_cuda.cpp',
                          f'{curr_dir}/rasterizer/standard_rasterize_cuda_kernel.cu'],
                 extra_cuda_cflags=['-std=c++14', '-ccbin=$$(which gcc-7)'])  # CUDA 10.2 is not compatible with gcc 9; specify gcc 7
        from standard_rasterize_cuda import standard_rasterize
        # If JIT compilation does not work, try manual installation first:
        # 1. see the instructions in pixielib/utils/rasterizer/INSTALL.md
        # 2. then replace the import above with:
        #    "from .rasterizer.standard_rasterize_cuda import standard_rasterize"
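
# Note: set_rasterizer() must be called before instantiating the classes
# below; it binds the backend symbols (Meshes/rasterize_meshes or
# standard_rasterize/load_obj) at module level, which the classes look up at
# call time. See the __main__ sketch at the bottom of this file.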


class StandardRasterizer(nn.Module):
    """ Alg: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation
    Notice:
        x,y,z are in image space, normalized to [-1, 1]
        can render non-squared image
        not differentiable
    """

    def __init__(self, height, width=None):
        """
        use fixed raster_settings for rendering faces
        """
        super().__init__()
        if width is None:
            width = height
        self.h = h = height
        self.w = w = width

    def forward(self, vertices, faces, attributes=None, h=None, w=None):
        device = vertices.device
        if h is None:
            h = self.h
        if w is None:
            w = self.w
        bz = vertices.shape[0]
        # buffers: depth initialized to far (1e6), triangle ids to -1 (no
        # face), barycentric weights to 0
        depth_buffer = torch.zeros([bz, h, w], dtype=torch.float32, device=device) + 1e6
        triangle_buffer = torch.zeros([bz, h, w], dtype=torch.int32, device=device) - 1
        baryw_buffer = torch.zeros([bz, h, w, 3], dtype=torch.float32, device=device)

        vertices = vertices.clone().float()
        # map x, y from [-1, 1] NDC to pixel coordinates; scale z by the same factor
        vertices[..., 0] = vertices[..., 0] * w / 2 + w / 2
        vertices[..., 1] = vertices[..., 1] * h / 2 + h / 2
        vertices[..., 2] = vertices[..., 2] * w / 2
        f_vs = util.face_vertices(vertices, faces)

        standard_rasterize(f_vs, depth_buffer, triangle_buffer, baryw_buffer,
                           h, w)
        pix_to_face = triangle_buffer[:, :, :, None].long()
        bary_coords = baryw_buffer[:, :, :, None, :]
        vismask = (pix_to_face > -1).float()
        D = attributes.shape[-1]
        attributes = attributes.clone()
        attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
                                     3, attributes.shape[-1])
        N, H, W, K, _ = bary_coords.shape
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        pix_to_face[mask] = 0
        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0  # Replace masked values in output.
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat(
            [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals
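

# Minimal self-contained illustration (toy values, not used by the classes)
# of the barycentric blending step inside forward() above: each pixel blends
# the attributes at its covering face's three vertices with that pixel's
# barycentric weights.
def _barycentric_interpolation_demo():
    face_attrs = torch.tensor([[[1., 0., 0.],    # attribute at vertex 0
                                [0., 1., 0.],    # attribute at vertex 1
                                [0., 0., 1.]]])  # attribute at vertex 2
    bary = torch.tensor([[0.2, 0.3, 0.5]])       # barycentric weights, sum to 1
    # weighted sum over the three vertices -> interpolated pixel attribute
    return (bary[..., None] * face_attrs).sum(dim=-2)  # tensor([[0.2000, 0.3000, 0.5000]])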


class Pytorch3dRasterizer(nn.Module):
    """  Borrowed from https://github.com/facebookresearch/pytorch3d
    This class implements methods for rasterizing a batch of heterogenous Meshes.
    Notice:
        x,y,z are in image space, normalized
        can only render squared image now
    """

    def __init__(self, image_size=224):
        """
        use fixed raster_settings for rendering faces
        """
        super().__init__()
        raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': False,
        }
        raster_settings = util.dict2obj(raster_settings)
        self.raster_settings = raster_settings

    def forward(self, vertices, faces, attributes=None, h=None, w=None):
        # h and w are accepted for interface parity with StandardRasterizer
        # but ignored here: output is always the fixed square image_size
        fixed_vertices = vertices.clone()
        # pytorch3d's screen-space convention flips x and y relative to ours
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        meshes_screen = Meshes(verts=fixed_vertices.float(),
                               faces=faces.long())
        raster_settings = self.raster_settings
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            perspective_correct=raster_settings.perspective_correct,
        )
        vismask = (pix_to_face > -1).float()
        D = attributes.shape[-1]
        attributes = attributes.clone()
        attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
                                     3, attributes.shape[-1])
        N, H, W, K, _ = bary_coords.shape
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        pix_to_face[mask] = 0
        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0  # Replace masked values in output.
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat(
            [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals


class SRenderY(nn.Module):

    def __init__(self,
                 image_size,
                 obj_filename,
                 uv_size=256,
                 rasterizer_type='standard'):
        super(SRenderY, self).__init__()
        self.image_size = image_size
        self.uv_size = uv_size

        if rasterizer_type == 'pytorch3d':
            self.rasterizer = Pytorch3dRasterizer(image_size)
            self.uv_rasterizer = Pytorch3dRasterizer(uv_size)
            verts, faces, aux = load_obj(obj_filename)
            uvcoords = aux.verts_uvs[None, ...]  # (N, V, 2)
            uvfaces = faces.textures_idx[None, ...]  # (N, F, 3)
            faces = faces.verts_idx[None, ...]
        elif rasterizer_type == 'standard':
            self.rasterizer = StandardRasterizer(image_size)
            self.uv_rasterizer = StandardRasterizer(uv_size)
            verts, uvcoords, faces, uvfaces = load_obj(obj_filename)
            verts = verts[None, ...]
            uvcoords = uvcoords[None, ...]
            faces = faces[None, ...]
            uvfaces = uvfaces[None, ...]
        else:
            raise NotImplementedError(
                f'unknown rasterizer_type: {rasterizer_type}')

        # faces
        dense_triangles = util.generate_triangles(uv_size, uv_size)
        self.register_buffer(
            'dense_faces',
            torch.from_numpy(dense_triangles).long()[None, :, :])
        self.register_buffer('faces', faces)
        self.register_buffer('raw_uvcoords', uvcoords)

        # uv coords
        uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0. + 1.],
                             -1)  # [bz, ntv, 3], append a homogeneous 1
        uvcoords = uvcoords * 2 - 1  # [0, 1] -> [-1, 1]
        uvcoords[..., 1] = -uvcoords[..., 1]  # flip v to image convention
        face_uvcoords = util.face_vertices(uvcoords, uvfaces)
        self.register_buffer('uvcoords', uvcoords)
        self.register_buffer('uvfaces', uvfaces)
        self.register_buffer('face_uvcoords', face_uvcoords)

        # shape colors, for rendering shape overlay
        colors = torch.tensor([180, 180, 180])[None, None, :].repeat(
            1, faces.max() + 1, 1).float() / 255.
        face_colors = util.face_vertices(colors, faces)
        self.register_buffer('vertex_colors', colors)
        self.register_buffer('face_colors', face_colors)

        # SH factors for lighting
        pi = np.pi
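        # nine prefactors, one per real spherical-harmonic basis function
        # (1 constant, 3 linear, 5 quadratic); they pair one-to-one with the
        # basis stack built in add_SHlight below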
        constant_factor = torch.tensor([
            1 / np.sqrt(4 * pi),
            ((2 * pi) / 3) * np.sqrt(3 / (4 * pi)),
            ((2 * pi) / 3) * np.sqrt(3 / (4 * pi)),
            ((2 * pi) / 3) * np.sqrt(3 / (4 * pi)),
            (pi / 4) * 3 * np.sqrt(5 / (12 * pi)),
            (pi / 4) * 3 * np.sqrt(5 / (12 * pi)),
            (pi / 4) * 3 * np.sqrt(5 / (12 * pi)),
            (pi / 4) * (3 / 2) * np.sqrt(5 / (12 * pi)),
            (pi / 4) * (1 / 2) * np.sqrt(5 / (4 * pi)),
        ]).float()
        self.register_buffer('constant_factor', constant_factor)

    def forward(self,
                vertices,
                transformed_vertices,
                albedos,
                lights=None,
                light_type='point',
                background=None,
                h=None,
                w=None):
        '''
        -- Texture Rendering
        vertices: [batch_size, V, 3], vertices in world space, for calculating normals, then shading
        transformed_vertices: [batch_size, V, 3], range: [-1, 1], projected vertices in image space, for rasterization
        albedos: [batch_size, 3, h, w], uv map
        lights:
            spherical harmonics: [N, 9(shcoeff), 3(rgb)]
            point/directional lighting: [N, n_lights, 6(xyzrgb)]
        light_type:
            point or directional
        '''
        batch_size = vertices.shape[0]
        # normalize z to [10, 90] for rasterization (pytorch3d near/far: 0-100)
        transformed_vertices = transformed_vertices.clone()
        z = transformed_vertices[:, :, 2]
        z = z - z.min()
        z = z / z.max()
        transformed_vertices[:, :, 2] = z * 80 + 10

        # attributes
        face_vertices = util.face_vertices(
            vertices, self.faces.expand(batch_size, -1, -1))
        normals = util.vertex_normals(vertices,
                                      self.faces.expand(batch_size, -1, -1))
        face_normals = util.face_vertices(
            normals, self.faces.expand(batch_size, -1, -1))
        transformed_normals = util.vertex_normals(
            transformed_vertices, self.faces.expand(batch_size, -1, -1))
        transformed_face_normals = util.face_vertices(
            transformed_normals, self.faces.expand(batch_size, -1, -1))
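        # packed per-face attribute layout (the rasterizer appends the alpha/
        # visibility mask as the final output channel):
        #   [0:3] uv coords, [3:6] transformed normals,
        #   [6:9] world-space vertices, [9:12] world-space normals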
        attributes = torch.cat([
            self.face_uvcoords.expand(batch_size, -1, -1, -1),
            transformed_face_normals.detach(),
            face_vertices.detach(), face_normals
        ], -1)

        # rasterize
        rendering = self.rasterizer(transformed_vertices,
                                    self.faces.expand(batch_size, -1, -1),
                                    attributes, h, w)

        ####
        # vis mask
        alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()

        # albedo: sample the uv texture at the rasterized uv coordinates
        uvcoords_images = rendering[:, :3, :, :]
        grid = uvcoords_images.permute(0, 2, 3, 1)[:, :, :, :2]
        albedo_images = F.grid_sample(albedos, grid, align_corners=False)

        # mask for pixels whose transformed normals point toward the camera
        # (negative z under this projection convention)
        transformed_normal_map = rendering[:, 3:6, :, :].detach()
        pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float()

        # shading
        normal_images = rendering[:, 9:12, :, :]
        if lights is not None:
            if lights.shape[1] == 9:
                shading_images = self.add_SHlight(normal_images, lights)
            else:
                if light_type == 'point':
                    vertice_images = rendering[:, 6:9, :, :].detach()
                    shading = self.add_pointlight(
                        vertice_images.permute(0, 2, 3, 1).reshape(
                            [batch_size, -1, 3]),
                        normal_images.permute(0, 2, 3, 1).reshape(
                            [batch_size, -1, 3]), lights)
                else:
                    shading = self.add_directionlight(
                        normal_images.permute(0, 2, 3, 1).reshape(
                            [batch_size, -1, 3]), lights)
                # both branches return [bz, h*w, 3]; fold back to image layout
                shading_images = shading.reshape([
                    batch_size, albedo_images.shape[2], albedo_images.shape[3], 3
                ]).permute(0, 3, 1, 2)
            images = albedo_images * shading_images
        else:
            images = albedo_images
            shading_images = images.detach() * 0.

        if background is None:
            images = images*alpha_images + \
                torch.ones_like(images).to(vertices.device)*(1-alpha_images)
        else:
            # background = F.interpolate(background, [self.image_size, self.image_size])
            images = images * alpha_images + background.contiguous() * (
                1 - alpha_images)

        outputs = {
            'images': images,
            'albedo_images': albedo_images,
            'alpha_images': alpha_images,
            'pos_mask': pos_mask,
            'shading_images': shading_images,
            'grid': grid,
            'normals': normals,
            'normal_images': normal_images,
            'transformed_normals': transformed_normals,
        }

        return outputs

    def add_SHlight(self, normal_images, sh_coeff):
        '''
            normal_images: [bz, 3, h, w]
            sh_coeff: [bz, 9, 3]
        '''
        N = normal_images
        sh = torch.stack([
            N[:, 0] * 0. + 1.,
            N[:, 0], N[:, 1], N[:, 2],
            N[:, 0] * N[:, 1], N[:, 0] * N[:, 2], N[:, 1] * N[:, 2],
            N[:, 0]**2 - N[:, 1]**2,
            3 * (N[:, 2]**2) - 1,
        ], 1)  # [bz, 9, h, w]
        sh = sh * self.constant_factor[None, :, None, None]
        # [bz, 9, 3, h, w]
        shading = torch.sum(
            sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1)
        return shading
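
    # Usage sketch for add_SHlight (toy values): ambient-only white light via
    # the DC coefficient:
    #   sh_coeff = torch.zeros(bz, 9, 3); sh_coeff[:, 0] = 1.
    #   shading = self.add_SHlight(normal_images, sh_coeff)  # [bz, 3, h, w]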

    def add_pointlight(self, vertices, normals, lights):
        '''
            vertices: [bz, nv, 3]
            normals: [bz, nv, 3]
            lights: [bz, nlight, 6], xyz position + rgb intensity
        returns:
            shading: [bz, nv, 3]
        '''
        light_positions = lights[:, :, :3]
        light_intensities = lights[:, :, 3:]
        directions_to_lights = F.normalize(
            light_positions[:, :, None, :] - vertices[:, None, :, :], dim=3)
        # note: not clamped to [0, 1], so back-facing points get negative shading
        normals_dot_lights = (normals[:, None, :, :] *
                              directions_to_lights).sum(dim=3)
        shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :]
        return shading.mean(1)

    def add_directionlight(self, normals, lights):
        '''
            normals: [bz, nv, 3]
            lights: [bz, nlight, 6], xyz direction + rgb intensity
        returns:
            shading: [bz, nv, 3]
        '''
        light_direction = lights[:, :, :3]
        light_intensities = lights[:, :, 3:]
        directions_to_lights = F.normalize(
            light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], -1),
            dim=3)
        normals_dot_lights = torch.clamp(
            (normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0., 1.)
        shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :]
        return shading.mean(1)
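
    # Usage sketch for add_directionlight (toy values): a single head-on
    # white light fully illuminating a camera-facing normal:
    #   normals = torch.tensor([[[0., 0., -1.]]])             # [1, 1, 3]
    #   lights = torch.tensor([[[0., 0., -1., 1., 1., 1.]]])  # xyz + rgb
    #   self.add_directionlight(normals, lights)  # -> tensor([[[1., 1., 1.]]])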

    def render_shape(self,
                     vertices,
                     transformed_vertices,
                     colors=None,
                     background=None,
                     detail_normal_images=None,
                     lights=None,
                     return_grid=False,
                     uv_detail_normals=None,
                     h=None,
                     w=None):
        '''
        -- rendering shape with detail normal map
        '''
        batch_size = vertices.shape[0]
        if lights is None:
            light_positions = torch.tensor([
                [-5, 5, -5],
                [5, 5, -5],
                [-5, -5, -5],
                [5, -5, -5],
                [0, 0, -5],
            ])[None, :, :].expand(batch_size, -1, -1).float()

            light_intensities = torch.ones_like(light_positions).float() * 1.7
            lights = torch.cat((light_positions, light_intensities),
                               2).to(vertices.device)
        # normalize z to [10, 90] for rasterization (pytorch3d near/far: 0-100)
        transformed_vertices = transformed_vertices.clone()
        z = transformed_vertices[:, :, 2]
        z = z - z.min()
        z = z / z.max()
        transformed_vertices[:, :, 2] = z * 80 + 10

        # Attributes
        face_vertices = util.face_vertices(
            vertices, self.faces.expand(batch_size, -1, -1))
        normals = util.vertex_normals(vertices,
                                      self.faces.expand(batch_size, -1, -1))
        face_normals = util.face_vertices(
            normals, self.faces.expand(batch_size, -1, -1))
        transformed_normals = util.vertex_normals(
            transformed_vertices, self.faces.expand(batch_size, -1, -1))
        transformed_face_normals = util.face_vertices(
            transformed_normals, self.faces.expand(batch_size, -1, -1))
        if colors is None:
            colors = self.face_colors.expand(batch_size, -1, -1, -1)
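        # packed per-face attribute layout (alpha appended by the rasterizer):
        #   [0:3] colors, [3:6] transformed normals, [6:9] world-space
        #   vertices, [9:12] world-space normals, [12:15] uv coords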
        attributes = torch.cat([
            colors,
            transformed_face_normals.detach(),
            face_vertices.detach(), face_normals,
            self.face_uvcoords.expand(batch_size, -1, -1, -1)
        ], -1)
        # rasterize
        rendering = self.rasterizer(transformed_vertices,
                                    self.faces.expand(batch_size, -1, -1),
                                    attributes, h, w)

        ####
        alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()

        # rendered shape colors (first three attribute channels)
        albedo_images = rendering[:, :3, :, :]
        # mask for camera-facing normals
        transformed_normal_map = rendering[:, 3:6, :, :].detach()
        pos_mask = (transformed_normal_map[:, 2:, :, :] < 0).float()

        # shading (vertex positions in channels 6:9 are rasterized but unused here)
        normal_images = rendering[:, 9:12, :, :].detach()
        if detail_normal_images is not None:
            normal_images = detail_normal_images
        if uv_detail_normals is not None:
            uvcoords_images = rendering[:, 12:15, :, :]
            grid = uvcoords_images.permute(0, 2, 3, 1)[:, :, :, :2]
            detail_normal_images = F.grid_sample(uv_detail_normals,
                                                 grid,
                                                 align_corners=False)
            normal_images = detail_normal_images

        shading = self.add_directionlight(
            normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
            lights)
        shading_images = shading.reshape(
            [batch_size, albedo_images.shape[2], albedo_images.shape[3],
             3]).permute(0, 3, 1, 2).contiguous()
        shaded_images = albedo_images * shading_images

        if background is None:
            shape_images = shaded_images*alpha_images + \
                torch.ones_like(shaded_images).to(
                    vertices.device)*(1-alpha_images)
        else:
            # background = F.interpolate(background, [self.image_size, self.image_size])
            shape_images = shaded_images*alpha_images + \
                background.contiguous()*(1-alpha_images)

        if return_grid:
            uvcoords_images = rendering[:, 12:15, :, :]
            grid = uvcoords_images.permute(0, 2, 3, 1)[:, :, :, :2]
            return shape_images, normal_images, grid
        else:
            return shape_images

    def render_depth(self, transformed_vertices):
        '''
        -- rendering depth
        '''
        transformed_vertices = transformed_vertices.clone()
        batch_size = transformed_vertices.shape[0]

        # shift z so the nearest point sits at depth 0
        tz = transformed_vertices[:, :, 2]
        transformed_vertices[:, :, 2] = tz - tz.min()
        # per-vertex depth attribute, flipped so nearer is larger, then
        # normalized to [0, 1]
        z = -transformed_vertices[:, :, 2:].repeat(1, 1, 3)
        z = z - z.min()
        z = z / z.max()
        # Attributes
        attributes = util.face_vertices(z,
                                        self.faces.expand(batch_size, -1, -1))
        # rasterize
        rendering = self.rasterizer(transformed_vertices,
                                    self.faces.expand(batch_size, -1, -1),
                                    attributes)

        ####
        alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()
        depth_images = rendering[:, :1, :, :]
        return depth_images

    def render_colors(self, transformed_vertices, colors, h=None, w=None):
        '''
        -- rendering per-vertex colors (rgb colors, normals, etc.)
            colors: [bz, number of vertices, 3]
        '''
        transformed_vertices = transformed_vertices.clone()
        batch_size = colors.shape[0]
        # normalize z to [10, 90] for rasterization (pytorch3d near/far: 0-100)
        z = transformed_vertices[:, :, 2]
        z = z - z.min()
        z = z / z.max()
        transformed_vertices[:, :, 2] = z * 80 + 10
        # Attributes
        attributes = util.face_vertices(colors,
                                        self.faces.expand(batch_size, -1, -1))
        # rasterize
        rendering = self.rasterizer(transformed_vertices,
                                    self.faces.expand(batch_size, -1, -1),
                                    attributes,
                                    h=h,
                                    w=w)
        ####
        alpha_images = rendering[:, [-1], :, :].detach()
        images = rendering[:, :3, :, :] * alpha_images
        return images

    def world2uv(self, vertices):
        '''
        project vertices from world space to uv space
        vertices: [bz, V, 3]
        uv_vertices: [bz, 3, h, w]
        '''
        batch_size = vertices.shape[0]
        face_vertices = util.face_vertices(
            vertices, self.faces.expand(batch_size, -1, -1))
        uv_vertices = self.uv_rasterizer(
            self.uvcoords.expand(batch_size, -1, -1),
            self.uvfaces.expand(batch_size, -1, -1), face_vertices)[:, :3]
        return uv_vertices
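

# Smoke-test sketch (assumptions: a FLAME-style template mesh at
# 'data/head_template.obj' and a CUDA device; neither is guaranteed by this
# module, so adjust path and device to your setup).
if __name__ == '__main__':
    set_rasterizer('pytorch3d')
    renderer = SRenderY(224, obj_filename='data/head_template.obj',
                        uv_size=256, rasterizer_type='pytorch3d').cuda()
    nv = renderer.faces.max().item() + 1
    vertices = torch.randn(1, nv, 3).cuda() * 0.1       # toy geometry
    transformed_vertices = vertices.clone()             # stand-in projection
    albedos = torch.rand(1, 3, 256, 256).cuda()         # random uv texture
    out = renderer(vertices, transformed_vertices, albedos)  # lights=None -> unlit
    print({k: tuple(v.shape) for k, v in out.items()})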