import torch
import torch.nn as nn
import torch.nn.functional as F


def iuvmap_clean(U_uv, V_uv, Index_UV, AnnIndex=None):
    """Harden per-channel scores into 0/1 masks and gate the UV maps with them.

    The winning channel of ``Index_UV`` is taken with ``argmax`` over dim 1
    and expanded back into one 0/1 plane per channel.  ``U_uv`` and ``V_uv``
    are multiplied element-wise by those planes, so only the winning
    channel keeps its value at each location.  ``AnnIndex`` is hardened the
    same way when it is given.

    Args:
        U_uv: tensor multiplied element-wise with the hardened index masks.
        V_uv: tensor multiplied element-wise with the hardened index masks.
        Index_UV: score tensor whose channels live on dim 1.
        AnnIndex: optional second score tensor (channels on dim 1).

    Returns:
        ``(recon_U, recon_V, recon_Index_UV, recon_Ann_Index)``; the last
        entry is ``None`` when ``AnnIndex`` is ``None``.
    """

    def _winner_masks(scores):
        # Winning channel per location, as float so F.threshold can act on it.
        winner = torch.argmax(scores, dim=1).float()
        planes = []
        for channel in range(scores.size(1)):
            # A value passes both thresholds only inside
            # (target - 0.5, target + 0.5); dividing the survivor by target
            # turns it into 1.  Channel 0 cannot be divided by itself, so it
            # is tested on the shifted value ``winner + 1`` against target 1.
            if channel == 0:
                value, target = winner + 1, 1
            else:
                value, target = winner, channel
            above_lower = F.threshold(value, target - 0.5, 0)
            below_upper = -F.threshold(-value, -target - 0.5, 0)
            planes.append(torch.min(above_lower, below_upper) / float(target))
        return torch.stack(planes, dim=1)

    recon_Index_UV = _winner_masks(Index_UV)
    recon_Ann_Index = None if AnnIndex is None else _winner_masks(AnnIndex)

    return recon_Index_UV * U_uv, recon_Index_UV * V_uv, recon_Index_UV, recon_Ann_Index


def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=None, n_part=24):
    """Collapse per-part U/V/index maps into a 3-channel IUV image per sample.

    For every sample the winning part is ``argmax(Index_UV, dim=1)``; where
    ``AnnIndex`` is given, locations whose ``AnnIndex`` argmax is 0 are
    forced to part 0.  The output channels are:

    * channel 0 - the part index, divided by ``K - 1`` (``K`` is
      ``U_uv.size(1)``), or remapped through ``ind_mapping`` and scaled by
      ``1 / n_part`` when a mapping is given;
    * channel 1 / 2 - the U / V value taken from the winning part's plane.

    When ``uv_rois`` is given, each image is additionally resized (nearest
    neighbour) to the ROI's aspect ratio and zero-padded back to a square of
    side ``U_uv.size(2)``.

    Args:
        U_uv: 4-D tensor ``(batch, K, H, W)`` of per-part U values.
        V_uv: tensor of per-part V values, indexed like ``U_uv``.
        Index_UV: per-part scores with parts on dim 1.
        AnnIndex: optional scores with channels on dim 1; channel 0 winning
            marks a location as part 0.
        uv_rois: optional per-sample rows; entries 1..4 are read as
            ``(x1, y1, x2, y2)``.  Entry 0 is skipped (presumably a batch
            index - confirm against the caller).
        ind_mapping: optional sequence supporting ``len()`` and integer
            indexing; ``ind_mapping[i]`` is the output part id for index
            ``i``.  Indices beyond its length keep their raw value.
        n_part: divisor used to scale remapped part ids.

    Returns:
        Tensor of shape ``(batch, 3, H, W)`` (``(batch, 3, H, H)`` when
        ``uv_rois`` is given), float32, on the device of ``U_uv``.
    """
    batch_size = U_uv.size(0)
    K = U_uv.size(1)
    heatmap_size = U_uv.size(2)

    # argmax already yields int64 indices.
    Index_UV_max = torch.argmax(Index_UV, dim=1)
    if AnnIndex is not None:
        AnnIndex_max = torch.argmax(AnnIndex, dim=1)
        Index_UV_max = Index_UV_max * (AnnIndex_max > 0).to(torch.int64)

    outputs = []

    for batch_id in range(batch_size):
        part_map = Index_UV_max[batch_id]

        # Allocate directly on the inputs' device; works for CPU and CUDA
        # alike and avoids a host-to-device copy.
        output = torch.zeros(
            [3, U_uv.size(2), U_uv.size(3)], dtype=torch.float32, device=U_uv.device
        )
        output[0] = part_map.to(torch.float32)
        if ind_mapping is None:
            output[0] /= float(K - 1)
        else:
            # Compare against the untouched integer map, not the channel being
            # rewritten: otherwise a remapped value that happens to equal a
            # later index (e.g. n_part / n_part == 1.0) is remapped twice.
            # Index-based access keeps any container with len()/[] working.
            for ind in range(len(ind_mapping)):
                output[0][part_map == ind] = ind_mapping[ind] * (1. / n_part)

        for part_id in range(K):
            part_mask = part_map == part_id
            output[1, part_mask] = U_uv[batch_id, part_id][part_mask]
            output[2, part_mask] = V_uv[batch_id, part_id][part_mask]

        output = output.unsqueeze(0)

        if uv_rois is not None:
            # roi_fg = (x1, y1, x2, y2)
            roi_fg = uv_rois[batch_id][1:]
            w = roi_fg[2] - roi_fg[0]
            h = roi_fg[3] - roi_fg[1]

            aspect_ratio = float(w) / h

            if aspect_ratio < 1:
                # Taller than wide: shrink the width, then centre horizontally.
                new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)]
                output = F.interpolate(output, size=new_size, mode='nearest')
                paddingleft = int(0.5 * (heatmap_size - new_size[1]))
                output = F.pad(
                    output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0)
                )
            else:
                # Wider than tall (or square): shrink the height, centre vertically.
                new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size]
                output = F.interpolate(output, size=new_size, mode='nearest')
                paddingtop = int(0.5 * (heatmap_size - new_size[0]))
                output = F.pad(
                    output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop)
                )

        outputs.append(output)

    return torch.cat(outputs, dim=0)


def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24):
    """Expand 3-channel IUV images into per-part U, V and index maps.

    Channel 0 of ``uvimages`` is multiplied by ``n_part`` and rounded to get
    a part index per pixel; channels 1 and 2 are read as U and V.  The index
    is expanded into ``n_part + 1`` 0/1 masks, U and V are gated by those
    masks, and the part masks are merged into 15 coarser masks according to
    the fixed ``Index2mask`` table.  Pixels whose rounded index is outside
    ``0..n_part`` (including negative values) are zero in every mask.

    When ``uv_rois`` is given, every sample is centre-cropped to the ROI's
    aspect ratio and resized (nearest neighbour) to ``new_size`` x
    ``new_size`` (the input width when ``new_size`` is ``None``).

    Args:
        uvimages: 4-D tensor ``(batch, >=3, H, W)``; the crop arithmetic uses
            the last dimension for both axes, so a square image is assumed.
        uv_rois: optional per-sample rows; entries 1..4 are read as
            ``(x1, y1, x2, y2)``.
        new_size: optional output side length for the ROI path.
        n_part: number of foreground parts.  ``Index2mask`` refers to part
            ids up to 24, so smaller values raise ``IndexError``.

    Returns:
        ``(recon_U, recon_V, recon_Index_UV, recon_Ann_Index)`` with
        ``n_part + 1`` channels for the first three and 15 for the last.
    """
    batch_size = uvimages.size(0)
    uvimg_size = uvimages.size(-1)

    # Part ids merged into each coarse mask.
    Index2mask = [
        [0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17], [16, 18],
        [19, 21], [20, 22], [23, 24]
    ]

    part_ind = torch.round(uvimages[:, 0, :, :] * n_part)
    part_u = uvimages[:, 1, :, :]
    part_v = uvimages[:, 2, :, :]

    # Exact equality gives a clean 0/1 mask for any rounded value; negative
    # or out-of-range indices simply match no channel.
    recon_Index_UV = torch.stack(
        [(part_ind == part_id).to(uvimages.dtype) for part_id in range(n_part + 1)], dim=1
    )
    recon_U = recon_Index_UV * part_u.unsqueeze(1)
    recon_V = recon_Index_UV * part_v.unsqueeze(1)

    # Part masks are mutually exclusive, so summing a group is a logical OR.
    recon_Ann_Index = torch.stack(
        [recon_Index_UV[:, group].sum(dim=1) for group in Index2mask], dim=1
    )

    full_maps = (recon_U, recon_V, recon_Index_UV, recon_Ann_Index)
    if uv_rois is None:
        return full_maps

    M = uvimg_size if new_size is None else new_size
    resized = [[] for _ in full_maps]

    for i in range(batch_size):
        # roi_fg = (x1, y1, x2, y2)
        roi_fg = uv_rois[i][1:]
        w = roi_fg[2] - roi_fg[0]
        h = roi_fg[3] - roi_fg[1]

        aspect_ratio = float(w) / h

        if aspect_ratio < 1:
            # Taller than wide: keep all rows, take a centred band of columns.
            w_size = max(int(uvimg_size * aspect_ratio), 1)
            w_margin = (uvimg_size - w_size) // 2
            rows, cols = slice(None), slice(w_margin, w_margin + w_size)
        else:
            # Wider than tall (or square): keep all columns, centred band of rows.
            h_size = max(int(uvimg_size / aspect_ratio), 1)
            h_margin = (uvimg_size - h_size) // 2
            rows, cols = slice(h_margin, h_margin + h_size), slice(None)

        for bucket, full_map in zip(resized, full_maps):
            crop = full_map[i, :, rows, cols]
            bucket.append(F.interpolate(crop.unsqueeze(0), size=(M, M), mode='nearest'))

    return tuple(torch.cat(bucket, dim=0) for bucket in resized)


def seg_img2map(segimages, uv_rois=None, new_size=None, n_part=24):
    """Expand part-segmentation images into per-part 0/1 masks.

    Channel 0 of ``segimages`` is multiplied by ``n_part`` and rounded to
    get a part index per pixel, which is expanded into ``n_part + 1`` masks.
    Pixels whose rounded index is outside ``0..n_part`` (including negative
    values) are zero in every mask.

    When ``uv_rois`` is given, every sample is centre-cropped to the ROI's
    aspect ratio and resized (nearest neighbour) to ``new_size`` x
    ``new_size`` (the input width when ``new_size`` is ``None``).

    Args:
        segimages: 4-D tensor ``(batch, >=1, H, W)``; the crop arithmetic
            uses the last dimension for both axes, so a square image is
            assumed.
        uv_rois: optional per-sample rows; entries 1..4 are read as
            ``(x1, y1, x2, y2)``.
        new_size: optional output side length for the ROI path.
        n_part: number of foreground parts.

    Returns:
        ``(None, None, recon_Index_UV, None)`` - a 4-tuple in which only the
        index masks are populated.
    """
    batch_size = segimages.size(0)
    uvimg_size = segimages.size(-1)

    part_ind = torch.round(segimages[:, 0, :, :] * n_part)

    # Exact equality gives a clean 0/1 mask for any rounded value; negative
    # or out-of-range indices simply match no channel.
    recon_Index_UV = torch.stack(
        [(part_ind == part_id).to(segimages.dtype) for part_id in range(n_part + 1)], dim=1
    )

    if uv_rois is None:
        return None, None, recon_Index_UV, None

    M = uvimg_size if new_size is None else new_size
    recon_Index_UV_roi = []

    for i in range(batch_size):
        # roi_fg = (x1, y1, x2, y2)
        roi_fg = uv_rois[i][1:]
        w = roi_fg[2] - roi_fg[0]
        h = roi_fg[3] - roi_fg[1]

        aspect_ratio = float(w) / h

        if aspect_ratio < 1:
            # Taller than wide: keep all rows, take a centred band of columns.
            w_size = max(int(uvimg_size * aspect_ratio), 1)
            w_margin = (uvimg_size - w_size) // 2
            crop = recon_Index_UV[i, :, :, w_margin:w_margin + w_size]
        else:
            # Wider than tall (or square): keep all columns, centred band of rows.
            h_size = max(int(uvimg_size / aspect_ratio), 1)
            h_margin = (uvimg_size - h_size) // 2
            crop = recon_Index_UV[i, :, h_margin:h_margin + h_size, :]

        recon_Index_UV_roi.append(
            F.interpolate(crop.unsqueeze(0), size=(M, M), mode='nearest')
        )

    return None, None, torch.cat(recon_Index_UV_roi, dim=0), None