import torch
import torch.nn.functional as F

import numpy as np
from scipy.io import loadmat

def init_spixel_grid(args,  b_train=True, ratio = 1, downsize = 16):
    """Build the per-pixel (x, y) coordinate grid for a square crop.

    Args:
      args : object exposing ``crop_size``; it is used for both the height
        and the width of the grid
      b_train, ratio, downsize : accepted for call compatibility; the body
        does not read them

    Returns:
      all_XY_feat (tensor) : float32 coordinate grid; channel 0 holds the
        column (x) index and channel 1 the row (y) index of every pixel.
        It is placed on the GPU when CUDA is available, on the CPU otherwise.

    Shape:
      all_XY_feat: (1, 2, crop_size, crop_size)
    """
    height = args.crop_size
    width = args.crop_size

    # 'ij' indexing: first output varies along rows (y), second along
    # columns (x).
    row_idx, col_idx = np.meshgrid(
        np.arange(height), np.arange(width), indexing='ij')

    # Stack as (x, y) and prepend the batch dimension.
    xy_grid = np.stack([col_idx, row_idx])[np.newaxis].astype(np.float32)
    all_XY_feat = torch.from_numpy(xy_grid)

    # An unconditional .cuda() would raise on CPU-only machines; only move
    # the grid when a CUDA device actually exists.
    if torch.cuda.is_available():
        all_XY_feat = all_XY_feat.cuda()

    return all_XY_feat

def label2one_hot_torch(labels, C=14):
    """ Expands an integer label map into a one-hot float tensor.

    Args:
      labels(tensor) : segmentation label; values are cast to long and used
        as channel indices, so they must lie in [0, C)
      C (integer) : number of classes in labels

    Returns:
      target (tensor) : float32 one-hot encoding of the input label, on the
        same device as labels

    Shape:
      labels: (B, 1, H, W)
      target: (B, C, H, W)
    """
    batch, _, height, width = labels.shape

    # scatter_ needs a long index tensor detached from autograd.
    class_index = labels.long().detach()

    # Zero buffer sharing the dtype and device of the labels.
    encoded = torch.zeros(
        (batch, C, height, width), dtype=labels.dtype, device=labels.device)
    # Write a 1 into the channel named by each pixel's label.
    encoded.scatter_(1, class_index, 1)

    return encoded.to(torch.float32)

# Colour lookup table, read at import time from the 'colors' entry of
# data/color150.mat (path is relative to the current working directory).
# The table is stacked four times along axis 0 -- presumably so label ids
# beyond the base table still index a colour; confirm against callers.
colors = np.concatenate([loadmat('data/color150.mat')['colors']] * 4)

def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    """Return the sorted unique values of ``ar`` after flattening it.

    Args:
      ar (array_like) : input data; it is copied and flattened, so the
        caller's array is never modified
      return_index (bool) : also return, for each unique value, the index of
        its first occurrence in the flattened input
      return_inverse (bool) : also return the indices that rebuild the
        flattened input from the unique values
      return_counts (bool) : also return how often each unique value occurs

    Returns:
      The array of unique values when no optional output is requested,
      otherwise a tuple ``(values[, index][, inverse][, counts])`` in that
      order. All optional outputs are ``np.intp`` arrays, including for
      empty input.
    """
    # flatten() always copies, so the in-place sort below is safe.
    ar = np.asanyarray(ar).flatten()

    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            return ar
        # The removed ``np.bool`` alias (NumPy >= 1.24) used to be the dtype
        # here; use intp so the placeholders match the non-empty outputs.
        requested = (return_index, return_inverse, return_counts)
        return (ar,) + tuple(np.empty(0, np.intp) for want in requested if want)

    if optional_indices:
        # A stable sort is required for return_index to report the *first*
        # occurrence of each value.
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()
        aux = ar

    # True at the first element of every run of equal values.
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        return aux[flag]

    ret = (aux[flag],)
    if return_index:
        ret += (perm[flag],)
    if return_inverse:
        # Running count of run starts = rank of each sorted element among
        # the unique values; undo the sort to map it back to input order.
        iflag = np.cumsum(flag) - 1
        inv_idx = np.empty(ar.shape, dtype=np.intp)
        inv_idx[perm] = iflag
        ret += (inv_idx,)
    if return_counts:
        # Distance between consecutive run starts (plus the end sentinel).
        idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
        ret += (np.diff(idx),)
    return ret

def colorEncode(labelmap, mode='RGB'):
    """Paint a 2-D label map with the module-level ``colors`` table.

    Args:
      labelmap (ndarray) : label map; it is cast to int and every
        non-negative label is used as a row index into ``colors``
      mode (str) : 'BGR' reverses the channel axis of the result; any other
        value returns the channels in table order

    Returns:
      uint8 array with the colour of each pixel's label; pixels carrying a
      negative label stay zero.

    Shape:
      labelmap: (H, W)
      return: (H, W, 3)
    """
    labelmap = labelmap.astype('int')
    height, width = labelmap.shape[0], labelmap.shape[1]
    painted = np.zeros((height, width, 3), dtype=np.uint8)

    for label in unique(labelmap):
        if label >= 0:
            # Every pixel belongs to exactly one label, so adding the masked
            # colour planes fills each pixel exactly once.
            pixel_mask = (labelmap == label)[:, :, np.newaxis]
            colour_plane = np.tile(colors[label], (height, width, 1))
            painted += pixel_mask * colour_plane

    return painted[:, :, ::-1] if mode == 'BGR' else painted

def get_edges(sp_label, sp_num):
    """Build the superpixel adjacency edge list of every image in a batch.

    Two superpixels are neighbours when their labels meet on a pair of
    8-connected pixels (vertical, horizontal or either diagonal).

    Args:
      sp_label (tensor) : superpixel label map; its values are cast to long
        and used as indices into an sp_num x sp_num adjacency matrix
      sp_num (integer) : number of superpixels

    Returns:
      edge_indices (list of tensor) : one long tensor per batch element whose
        columns are (source, target) superpixel pairs. Every edge appears in
        both directions and columns follow the row-major order of the
        adjacency matrix. Tensors live on the GPU when CUDA is available,
        on the CPU otherwise.

    Shape:
      sp_label: (B, 1, H, W)
      edge_indices[i]: (2, E_i)
    """
    # Results go to the GPU as before, but CPU-only hosts no longer crash.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    edge_indices = []
    # Select the single channel explicitly: squeeze() would also drop H or W
    # when either equals 1 and break the 2-D indexing below.
    for image in sp_label[:, 0]:
        # Dense adjacency matrix; set its diagonal to True to add self-loops.
        adjacency = torch.zeros(
            sp_num, sp_num, dtype=torch.bool, device=device)

        # Aligned views pairing each pixel with one of its neighbours.
        neighbour_views = (
            (image[:-1, :], image[1:, :]),      # top / bottom
            (image[:, :-1], image[:, 1:]),      # left / right
            (image[:-1, :-1], image[1:, 1:]),   # top-left / bottom-right
            (image[:-1, 1:], image[1:, :-1]),   # top-right / bottom-left
        )
        for first, second in neighbour_views:
            # Only pixel pairs straddling two different superpixels matter.
            boundary = first != second
            src = first[boundary].long().to(device)
            dst = second[boundary].long().to(device)
            adjacency[src, dst] = True
            adjacency[dst, src] = True

        edge_indices.append(
            torch.stack(torch.nonzero(adjacency, as_tuple=True)))
    return edge_indices


def draw_color_seg(seg):
    """Render a batch of label maps as colour images.

    Each batch element is squeezed, coloured with ``colorEncode`` and scaled
    from 0..255 to 0..1.

    Args:
      seg (tensor) : batch of label maps; each element must squeeze to a
        2-D (H, W) map

    Returns:
      float32 CPU tensor of colour images, channels first.

    Shape:
      return: (B, 3, H, W)
    """
    label_maps = seg.detach().cpu().numpy()
    rendered = [
        # (H, W, 3) uint8 -> (3, H, W) float in [0, 1]
        torch.from_numpy(colorEncode(label_map.squeeze()) / 255.0)
        .float()
        .permute(2, 0, 1)
        for label_map in label_maps
    ]
    return torch.stack(rendered)