# This script is borrowed and extended from https://github.com/shunsukesaito/PIFu/blob/master/lib/model/SurfaceClassifier.py

import torch
import scipy
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from lib.pymafx.core import path_config
from lib.pymafx.utils.geometry import projection

import logging

# Module-level logger named after this module (not referenced by the classes below).
logger = logging.getLogger(__name__)

from .transformers.net_utils import PosEnSine
from .transformers.transformer_basics import OurMultiheadAttention

from lib.pymafx.utils.imutils import j2d_processing


class TransformerDecoderUnit(nn.Module):
    ''' Cross-attention decoder unit.

    Queries attend to (k, v) through ``OurMultiheadAttention``. The attended features then go
    through a residual feed-forward built from two 1x1 ``Conv2d`` layers, followed by
    ``BatchNorm2d``. When ``pos_en_flag`` is set, ``PosEnSine`` embeddings are concatenated to
    q and k along dim 1 before attention. The attention input widths below assume these
    embeddings add ``3 * pe_dim`` channels (NOTE(review): confirm against ``PosEnSine``).
    '''
    def __init__(
        self, feat_dim, attri_dim=0, n_head=8, pos_en_flag=True, attn_type='softmax', P=None
    ):
        super().__init__()
        self.feat_dim = feat_dim
        self.attn_type = attn_type
        self.pos_en_flag = pos_en_flag
        self.P = P

        # Only attribute-free queries are supported.
        assert attri_dim == 0

        pe_dim = 10 if self.pos_en_flag else 0
        if self.pos_en_flag:
            self.pos_en = PosEnSine(pe_dim)
        pe_channels = pe_dim * 3

        # Cross-attention. The q/k widths include the positional-encoding channels that
        # forward() concatenates onto q and k.
        self.attn = OurMultiheadAttention(
            feat_dim + attri_dim + pe_channels, feat_dim + pe_channels, feat_dim, n_head
        )

        # Position-wise feed-forward (1x1 convolutions) with a residual connection + BatchNorm.
        width = self.feat_dim
        self.linear1 = nn.Conv2d(width, width, 1)
        self.linear2 = nn.Conv2d(width, width, 1)
        self.activation = nn.ReLU(inplace=True)
        self.norm = nn.BatchNorm2d(width)

    def forward(self, q, k, v, pos=None):
        ''' Attend q over (k, v), then apply the residual feed-forward and BatchNorm.

        :param q: query features (channels along dim 1)
        :param k: key features (channels along dim 1)
        :param v: value features
        :param pos: optional positions passed to the query positional encoding
        :return: normalized output of the residual feed-forward block
        '''
        if self.pos_en_flag:
            q_pe, k_pe = self.pos_en(q, pos), self.pos_en(k)
            q, k = torch.cat([q, q_pe], dim=1), torch.cat([k, k_pe], dim=1)

        attended = self.attn(q=q, k=k, v=v, attn_type=self.attn_type, P=self.P)[0]

        ffn = self.linear2(self.activation(self.linear1(attended)))
        return self.norm(attended + ffn)


class Mesh_Sampler(nn.Module):
    ''' Mesh Up/Down-sampling

    Precomputes dense sampling matrices from a mesh-downsampling graph file and applies them
    with a batched matmul:
      * ``Dmap``: [N_coarse, N_full], maps full-resolution vertices to the coarse mesh
      * ``Umap``: [N_full, N_coarse], maps coarse vertices back to full resolution
    For SMPL, level=1 is 6890 <-> 1723 vertices and level=2 is 6890 <-> 431 vertices.
    For MANO, level=1 is 778 <-> 195 vertices and level=2 is 778 <-> 49 vertices.
    '''
    def __init__(self, type='smpl', level=2, device=torch.device('cuda'), option=None):
        '''
        :param type: body model whose graph to load, 'smpl' or 'mano'
        :param level: number of coarsening steps to compose, 1 or 2
        :param device: unused; the registered buffers move with the module via ``.to()``
        :param option: unused; kept for interface compatibility
        :raises ValueError: for an unsupported ``type`` or ``level``
        '''
        super().__init__()

        # Validate up front instead of failing later with an undefined-name error.
        if level not in (1, 2):
            raise ValueError(f"Unsupported level {level!r}; expected 1 or 2")

        if type == 'smpl':
            # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz
            graph_path = path_config.SMPL_DOWNSAMPLING
        elif type == 'mano':
            # from https://github.com/microsoft/MeshGraphormer/blob/main/src/modeling/data/mano_downsampling.npz
            graph_path = path_config.MANO_DOWNSAMPLING
        else:
            raise ValueError(f"Unsupported mesh type {type!r}; expected 'smpl' or 'mano'")

        # The context manager closes the .npz archive once U and D have been read.
        with np.load(graph_path, allow_pickle=True, encoding='latin1') as mesh_graph:
            U = mesh_graph['U']
            D = mesh_graph['D']    # shape: (2,)

        ptD = [self._scipy_to_torch_sparse(m) for m in D]
        ptU = [self._scipy_to_torch_sparse(m) for m in U]

        # ptD[0]: [1723, 6890] / [195, 778];  ptD[1]: [431, 1723] / [49, 195]
        # ptU[0]: [6890, 1723];               ptU[1]: [1723, 431]
        if level == 2:
            # Compose the two steps: full -> intermediate -> coarse (and back).
            Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense())
            Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense())
        else:
            Dmap = ptD[0].to_dense()
            Umap = ptU[0].to_dense()
        self.register_buffer('Dmap', Dmap)
        self.register_buffer('Umap', Umap)

    @staticmethod
    def _scipy_to_torch_sparse(mat):
        ''' Convert a scipy sparse (or dense 2D) matrix to a float32 torch sparse COO tensor. '''
        coo = scipy.sparse.coo_matrix(mat)
        indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
        values = torch.as_tensor(coo.data, dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, values, coo.shape)

    def downsample(self, x):
        ''' Map [B, N_full, C] vertex features to [B, N_coarse, C] (e.g. [B, 431, 3]). '''
        return torch.matmul(self.Dmap.unsqueeze(0), x)

    def upsample(self, x):
        ''' Map [B, N_coarse, C] vertex features to [B, N_full, C] (e.g. [B, 6890, 3]). '''
        return torch.matmul(self.Umap.unsqueeze(0), x)

    def forward(self, x, mode='downsample'):
        ''' Dispatch to ``downsample`` or ``upsample``.

        :raises ValueError: for any ``mode`` other than 'downsample' or 'upsample'
        '''
        if mode == 'downsample':
            return self.downsample(x)
        if mode == 'upsample':
            return self.upsample(x)
        raise ValueError(f"Unsupported mode {mode!r}; expected 'downsample' or 'upsample'")


class MAF_Extractor(nn.Module):
    ''' Mesh-aligned Feature Extractor
    As discussed in the paper, we extract mesh-aligned features based on the 2D projection of the mesh vertices.
    The features extracted from spatial feature maps go through an MLP (1x1 Conv1d layers) for dimension reduction.

    The module also registers two dense buffers built from the SMPL mesh-downsampling graph:
    ``Dmap`` (6890 -> 431 vertices) and ``Umap`` (431 -> 6890 vertices).
    '''
    def __init__(
        self, filter_channels, device=torch.device('cuda'), iwp_cam_mode=True, option=None
    ):
        '''
        :param filter_channels: channel sizes of the reduction MLP, [C_s, ..., C_p]; every
            layer after the first also receives the raw C_s input features (skip concat)
        :param device: stored as ``self.device``; not used for placement in this class
        :param iwp_cam_mode: selects how ``forward`` normalizes projected points
        :param option: unused; kept for interface compatibility
        '''
        super().__init__()

        self.device = device
        self.num_views = 1
        self.last_op = nn.ReLU(True)

        self.iwp_cam_mode = iwp_cam_mode

        # A plain list on purpose: the layers are registered as conv0, conv1, ... through
        # add_module, so existing checkpoints (keys "convN.weight"/"convN.bias") keep loading.
        self.filters = self._build_filters(filter_channels)

        self._register_sampling_maps()

    def _build_filters(self, filter_channels):
        ''' Create and register the 1x1 Conv1d layers of the reduction MLP; return them in order. '''
        filters = []
        for l in range(len(filter_channels) - 1):
            # Layers after the first see [previous output, raw input] concatenated on channels.
            in_channels = filter_channels[l] if l == 0 else filter_channels[l] + filter_channels[0]
            conv = nn.Conv1d(in_channels, filter_channels[l + 1], 1)
            self.add_module("conv%d" % l, conv)
            filters.append(conv)
        return filters

    def _register_sampling_maps(self):
        ''' Load the SMPL mesh-downsampling graph and register the dense ``Dmap``/``Umap`` buffers. '''
        # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz
        # The context manager closes the .npz archive once U and D have been read.
        with np.load(
            path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1'
        ) as smpl_mesh_graph:
            U = smpl_mesh_graph['U']
            D = smpl_mesh_graph['D']    # shape: (2,)

        # ptD[0]: [1723, 6890], ptD[1]: [431, 1723] -> Dmap: [431, 6890] (6890 -> 431)
        ptD = [self._scipy_to_torch_sparse(m) for m in D]
        Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense())
        self.register_buffer('Dmap', Dmap)

        # ptU[0]: [6890, 1723], ptU[1]: [1723, 431] -> Umap: [6890, 431] (431 -> 6890)
        ptU = [self._scipy_to_torch_sparse(m) for m in U]
        Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense())
        self.register_buffer('Umap', Umap)

    @staticmethod
    def _scipy_to_torch_sparse(mat):
        ''' Convert a scipy sparse (or dense 2D) matrix to a float32 torch sparse COO tensor. '''
        coo = scipy.sparse.coo_matrix(mat)
        indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
        values = torch.as_tensor(coo.data, dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, values, coo.shape)

    def reduce_dim(self, feature):
        '''
        Dimension reduction by a multi-layer perceptron of 1x1 Conv1d layers.
        :param feature: [B, C_s, N] point-wise features before dimension reduction
        :return: [B, C_p, N] point-wise features after dimension reduction (non-negative,
            because ``last_op`` is a ReLU)
        '''
        y = feature
        tmpy = feature
        num_layers = len(self.filters)
        for i in range(num_layers):
            conv = self._modules['conv' + str(i)]
            y = conv(y if i == 0 else torch.cat([y, tmpy], 1))
            if i != num_layers - 1:
                y = F.leaky_relu(y)
            if self.num_views > 1 and i == num_layers // 2:
                # Average over views halfway through the MLP (multi-view fusion).
                y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1)
                tmpy = feature.view(-1, self.num_views, feature.shape[1],
                                    feature.shape[2]).mean(dim=1)

        return self.last_op(y)

    def sampling(self, points, im_feat=None, z_feat=None, add_att=False, reduce_dim=True):
        '''
        Given 2D points, bilinearly sample the point-wise features for each point.
        Optionally reduce their dimension from C_s to C_p with the MLP.
        Image features should be pre-computed before this call.
        :param points: [B, N, 2] image coordinates normalized to [-1, 1], in the (x, y) order
            expected by ``grid_sample``
        :param im_feat: [B, C_s, H_s, W_s] spatial feature maps
        :param z_feat: unused
        :param add_att: unused
        :param reduce_dim: if False, return the raw sampled features
        :return: [B, C_p, N] reduced features, or [B, C_s, N] when ``reduce_dim`` is False
        '''
        # grid [B, N, 1, 2] -> sampled [B, C_s, N, 1] -> drop the dummy width axis
        point_feat = F.grid_sample(im_feat, points.unsqueeze(2), align_corners=False)[..., 0]

        if reduce_dim:
            return self.reduce_dim(point_feat)
        return point_feat

    def forward(self, p, im_feat, cam=None, add_att=False, reduce_dim=True, **kwargs):
        ''' Returns mesh-aligned features for the 3D mesh points.
        Args:
            p (tensor): [B, N_m, 3] mesh vertices
            im_feat (tensor): [B, C_s, H_s, W_s] spatial feature maps
            cam: [B, 3] camera when ``iwp_cam_mode`` is True. Otherwise it is a mapping that
                provides ``cam['kps_transf']`` for ``j2d_processing``.
            add_att: forwarded to ``sampling`` (unused there)
            reduce_dim (bool): whether to apply the dimension-reduction MLP
        Return:
            mesh_align_feat (tensor): [B, C_p, N_m] mesh-aligned features
                ([B, C_s, N_m] when ``reduce_dim`` is False)
        '''
        p_proj_2d = projection(p, cam, retain_z=False, iwp_mode=self.iwp_cam_mode)
        if self.iwp_cam_mode:
            # Normalize keypoints to [-1, 1]. This assumes the projection yields pixel offsets
            # from the center of a 224x224 crop (NOTE(review): confirm against `projection`).
            p_proj_2d = p_proj_2d / (224. / 2.)
        else:
            p_proj_2d = j2d_processing(p_proj_2d, cam['kps_transf'])
        return self.sampling(p_proj_2d, im_feat, add_att=add_att, reduce_dim=reduce_dim)