import os

import torch
import torch.utils.data
from torch import nn
import torch.nn.functional as F

import numpy as np
import cv2

from bert.multimodal_bert import MultiModalBert
from bert.modeling_bert import BertLMPredictionHead, BertEncoder
from lib import multimodal_segmentation_ppm
from modeling.MaskFormerModel import MaskFormerHead
from addict import Dict

import transforms as T
import utils

def get_dataset(image_set, transform, args):
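    """Build the ReferDataset split for evaluation. eval_mode=True makes each
    sample carry all of its referring expressions (the extra expression
    dimension is iterated over in evaluate())."""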
    from data.dataset_refer_bert_vis import ReferDataset
    ds = ReferDataset(args,
                      split=image_set,
                      image_transforms=transform,
                      target_transforms=None,
                      eval_mode=True
                      )
    num_classes = 2
    return ds, num_classes


def overlay_davis(image, mask, colors=[[0, 0, 0], [0, 255, 0]], cscale=1, alpha=0.4):
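    """Overlay `mask` onto `image` DAVIS-style: tint each object id with its
    color and draw a black one-pixel contour around it.

    Expects `image` as an HxWx3 array and `mask` as an HxW integer label map
    with 0 as background; the result keeps `image`'s dtype.
    """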
    from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated

    colors = np.reshape(colors, (-1, 3))
    colors = np.atleast_2d(colors) * cscale

    im_overlay = image.copy()
    object_ids = np.unique(mask)

    for object_id in object_ids[1:]:
        # Blend the object's color into the image where the mask is set
        foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
        binary_mask = mask == object_id

        # Compose image
        im_overlay[binary_mask] = foreground[binary_mask]

        # One-pixel boundary: dilation XOR mask
        contours = binary_dilation(binary_mask) ^ binary_mask
        im_overlay[contours, :] = 0

    return im_overlay.astype(image.dtype)

def evaluate(model, data_loader, device):
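    """Run referring-segmentation evaluation: per-expression IoU,
    precision@{0.5,...,0.9}, cumulative (overall) IoU, plus one qualitative
    overlay image written to disk per (image, expression) pair.
    """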
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []
    header = 'Test:'

    # Qualitative results are written here; create the directory up front so
    # cv2.imwrite does not fail silently on a missing path.
    out_dir = 'vis/elia_refcoco+_green'
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        idx = 0
        for data in metric_logger.log_every(data_loader, 100, header):
            idx += 1
            image, target, sentences, attentions, raw_sentences, this_img, orig_img = data
            image, target, sentences, attentions = image.to(device), target.to(device), \
                                                   sentences.to(device), attentions.to(device)

            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)
            target = target.cpu().data.numpy()

            orig_shape = orig_img.shape
            orig_img = orig_img.numpy()[:, :, :, ::-1]  # RGB -> BGR for OpenCV

            image_mean_iou = []
            for j in range(sentences.size(-1)):
                output = model(image, sentences[:, :, j], attentions[:, :, j])
                mask_cls_results = output["pred_logits"]
                mask_pred_results = output["pred_masks"]

                target_shape = target.shape[-2:]
                mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)

                pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
                output = pred_masks[0]
                output = output.cpu()

                # Binarize the soft mask at 0.5 for IoU against the target
                output_mask = (output > 0.5).data.numpy()

                # Resize the soft mask to the original image resolution and
                # binarize it for the qualitative overlay
                orig_mask = F.interpolate(pred_masks, (orig_shape[1], orig_shape[2]))
                orig_mask = (orig_mask > 0.5).data.cpu().numpy()

                overlay = overlay_davis(orig_img[0], orig_mask[0][0].astype(np.uint8))
                I, U = computeIoU(output_mask, target)
                if U == 0:
                    this_iou = 0.0
                else:
                    this_iou = I*1.0/U
                mean_IoU.append(this_iou)
                image_mean_iou.append(this_iou)
                cum_I += I
                cum_U += U
                for n_eval_iou in range(len(eval_seg_iou_list)):
                    eval_seg_iou = eval_seg_iou_list[n_eval_iou]
                    seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
                seg_total += 1
                cv2.imwrite("{:s}/{:s}_{:d}_{:d}_{:.2f}.jpg".format(out_dir, this_img[0].split('.')[0], idx, j, this_iou), overlay.astype(np.uint8))

            image_mean_iou = np.mean(np.array(image_mean_iou))
            print('image mean IoU: {:.4f}'.format(image_mean_iou))

    mean_IoU = np.array(mean_IoU)
    mIoU = np.mean(mean_IoU)
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU*100.))
    results_str = ''
    for n_eval_iou in range(len(eval_seg_iou_list)):
        results_str += '    precision@%s = %.2f\n' % \
                       (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
    results_str += '    overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
    print(results_str)



def get_transform(args):
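    """Resize to a square args.img_size x args.img_size, convert to a tensor,
    and normalize with ImageNet mean/std."""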
    transforms = [T.Resize(args.img_size, args.img_size),
                  T.ToTensor(),
                  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                  ]

    return T.Compose(transforms)


def computeIoU(pred_seg, gd_seg):
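    """Return (intersection, union) pixel counts between two binary masks.

    Toy example: pred = [[1, 1, 0]], gt = [[0, 1, 1]] gives I, U = 1, 3,
    i.e. IoU = 1/3.
    """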
    I = np.sum(np.logical_and(pred_seg, gd_seg))
    U = np.sum(np.logical_or(pred_seg, gd_seg))

    return I, U

class WrapperModel(nn.Module):
    def __init__(self, image_model, language_model, classifier, args) :
        super(WrapperModel, self).__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier
        self.lang_proj = nn.Linear(768,256)

        config = Dict({
          "architectures": [
           "BertForMaskedLM"
          ],
          "attention_probs_dropout_prob": 0.1,
          "gradient_checkpointing": False,
          "hidden_act": "gelu",
          "hidden_dropout_prob": 0.1,
          "hidden_size": 512,
          "initializer_range": 0.02,
          "intermediate_size": 3072,
          "layer_norm_eps": 1e-12,
          #"max_position_embeddings": 16+20,
          "model_type": "bert",
          "num_attention_heads": 8,
          "num_hidden_layers": 8,
         "pad_token_id": 0,
          "position_embedding_type": "absolute",
          "transformers_version": "4.6.0.dev0",
          "type_vocab_size": 2,
          "use_cache": True,
          "vocab_size": 30522
        })
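        # A compact BERT encoder (8 layers, 8 heads, hidden size 512) used as
        # the MLM transformer; 512 matches the projections defined below.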
        self.mlm_transformer = BertEncoder(config)

        self.mlm_vis_proj = nn.Conv2d(1024,512,1)
        self.mlm_lang_proj = nn.Linear(768,512)
        self.mlm_head = BertLMPredictionHead(config)

        assert args.img_size % 4 == 0
        num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
        self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
        self.mlm_modal_embeds = nn.Embedding(3, 512)

        self.mlm_mask_embed = nn.Embedding(1, 512)
        self.mlm_pos_mlp = nn.Sequential(
            nn.Linear(2, 512),
            nn.LayerNorm(512),
            nn.Linear(512,512),
            nn.GELU()
        )

    def _get_binary_mask(self, target):
        # Return one binary mask per class (background channel dropped).
        # Note: relies on self.num_classes, which is not set in __init__;
        # this helper is unused in the evaluation path.
        y, x = target.size()
        target_onehot = torch.zeros(self.num_classes + 1, y, x)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):       
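        """Fuse per-query class scores with per-query mask logits into a
        (B, C, H, W) semantic map: masks are sigmoid-squashed and summed,
        weighted by the (background-stripped) class probabilities.
        """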
        mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
        mask_pred = mask_pred.sigmoid()      
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)        
        return semseg

    def forward(self, image, sentences, attentions): 
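        """Run the paired vision/language encoders stage by stage, applying
        PWAM cross-modal fusion after each stage, then pass the four residual
        feature maps to the MaskFormer-style classifier head.
        """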
        input_shape = image.shape[-2:]
        l_mask = attentions.unsqueeze(dim=-1)

        i0, Wh, Ww = self.image_model.forward_stem(image)
        l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)

        i1 = self.image_model.forward_stage1(i0, Wh, Ww)
        l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
        i1_residual, H, W, i1_temp, Wh, Ww  = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
        l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) 
        i1 = i1_temp

        i2 = self.image_model.forward_stage2(i1, Wh, Ww)
        l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
        i2_residual, H, W, i2_temp, Wh, Ww  = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
        l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) 
        i2 = i2_temp

        i3 = self.image_model.forward_stage3(i2, Wh, Ww)
        l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
        i3_residual, H, W, i3_temp, Wh, Ww  = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
        l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) 
        i3 = i3_temp

        i4 = self.image_model.forward_stage4(i3, Wh, Ww)
        l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
        i4_residual, H, W, i4_temp, Wh, Ww  = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
        l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) 
        i4 = i4_temp

        outputs = {}
        outputs['s1'] = i1_residual
        outputs['s2'] = i2_residual
        outputs['s3'] = i3_residual
        outputs['s4'] = i4_residual

        predictions = self.classifier(outputs)
        return predictions

def main(args):
    device = 'cuda'
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
                                  sampler=test_sampler, num_workers=args.workers)
    print(args.model)
    single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args)
    checkpoint = torch.load(args.resume, map_location='cpu')

    if args.model != 'lavt_one':
        model_class = MultiModalBert
        single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim)
        # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
        if args.ddp_trained_weights:
            single_bert_model.pooler = None
    else:
        # WrapperModel below still requires a language model, so the
        # 'lavt_one' path is effectively unsupported by this script (the
        # original code left single_bert_model undefined here: a NameError).
        single_bert_model = None
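    # Per-stage feature channels/strides fed to the MaskFormer head; the
    # channels follow the backbone's embed_dim of 128 doubling at each stage.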
    input_shape = dict()
    input_shape['s1'] = Dict({'channel': 128,  'stride': 4})
    input_shape['s2'] = Dict({'channel': 256,  'stride': 8})
    input_shape['s3'] = Dict({'channel': 512,  'stride': 16})
    input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})

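    # Minimal MaskFormerHead configuration: a single foreground class and a
    # single object query, matching binary referring segmentation.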
    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]

    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
    cfg.MODEL.MASK_FORMER.PRE_NORM = False


    maskformer_head = MaskFormerHead(cfg, input_shape)

    model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    evaluate(model, data_loader_test, device=device)


if __name__ == "__main__":
    from args import get_parser
    parser = get_parser()
    args = parser.parse_args()
    print('Image size: {}'.format(str(args.img_size)))
    print(args)
    main(args)