import torch
import torch.utils.data
from torch import nn
import torch.nn.functional as F

import numpy as np
from addict import Dict

from bert.multimodal_bert import MultiModalBert
from bert.modeling_bert import BertLMPredictionHead, BertEncoder
from modeling.MaskFormerModel import MaskFormerHead
from lib import multimodal_segmentation_ppm
import transforms as T
import utils

def get_dataset(image_set, transform, args):
    from data.dataset_refer_bert import ReferDataset
    ds = ReferDataset(args,
                      split=image_set,
                      image_transforms=transform,
                      target_transforms=None,
                      eval_mode=True
                      )
    num_classes = 2  # background + referred object
    return ds, num_classes


def evaluate(model, data_loader, device):
    """Evaluate referring segmentation: accumulates per-expression IoU and
    reports mean IoU, precision@{0.5,...,0.9}, and overall (cumulative) IoU."""
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []
    header = 'Test:'

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, sentences, attentions = data
            image, target, sentences, attentions = image.to(device), target.to(device), \
                                                   sentences.to(device), attentions.to(device)
            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)
            target = target.cpu().data.numpy()
            for j in range(sentences.size(-1)):
                output = model(image, sentences[:, :, j], attentions[:, :, j])
                mask_cls_results = output["pred_logits"]
                mask_pred_results = output["pred_masks"]

                target_shape = target.shape[-2:]
                mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)

                pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)                
                output = pred_masks[0]

                output = output.cpu()
                # binarize the single-channel probability map at 0.5
                output_mask = (output > 0.5).data.numpy()
                I, U = computeIoU(output_mask, target)
                if U == 0:
                    this_iou = 0.0
                else:
                    this_iou = I*1.0/U
                mean_IoU.append(this_iou)
                cum_I += I
                cum_U += U
                for n_eval_iou in range(len(eval_seg_iou_list)):
                    eval_seg_iou = eval_seg_iou_list[n_eval_iou]
                    seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
                seg_total += 1

    mean_IoU = np.array(mean_IoU)
    mIoU = np.mean(mean_IoU)
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU*100.))
    results_str = ''
    for n_eval_iou in range(len(eval_seg_iou_list)):
        results_str += '    precision@%s = %.2f\n' % \
                       (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
    results_str += '    overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
    print(results_str)


def get_transform(args):
    # resize to a square crop, convert to tensor, and apply ImageNet
    # mean/std normalization
    transforms = [T.Resize(args.img_size, args.img_size),
                  T.ToTensor(),
                  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                  ]
    return T.Compose(transforms)


def computeIoU(pred_seg, gd_seg):
    # intersection and union pixel counts between two binary masks
    I = np.sum(np.logical_and(pred_seg, gd_seg))
    U = np.sum(np.logical_or(pred_seg, gd_seg))
    return I, U
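
# Toy sanity check for computeIoU (hypothetical inputs, not executed here):
# computeIoU(np.array([[1, 0], [0, 0]], dtype=bool),
#            np.array([[1, 1], [0, 0]], dtype=bool))
# returns (1, 2), so IoU = 1 / 2 = 0.5.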

class WrapperModel(nn.Module):
    def __init__(self, image_model, language_model, classifier, args) :
        super(WrapperModel, self).__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier

        config = Dict({
            "architectures": ["BertForMaskedLM"],
            "attention_probs_dropout_prob": 0.1,
            "gradient_checkpointing": False,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            #"max_position_embeddings": 16+20,
            "model_type": "bert",
            "num_attention_heads": 8,
            "num_hidden_layers": 8,
            "pad_token_id": 0,
            "position_embedding_type": "absolute",
            "transformers_version": "4.6.0.dev0",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 30522
        })
        self.mlm_transformer = BertEncoder(config)

        self.lang_proj = nn.Linear(768, 256)
        self.mlm_vis_proj = nn.Conv2d(1024, 512, 1)
        self.mlm_lang_proj = nn.Linear(768, 512)
        self.mlm_head = BertLMPredictionHead(config)

        assert args.img_size % 4 == 0
        # 20 text positions plus one token per stride-32 feature-map position
        num_img_tokens = 20 + ((args.img_size // 4) // 8) ** 2
        self.mlm_pos_embeds = nn.Embedding(num_img_tokens + 1, 512)
        self.mlm_modal_embeds = nn.Embedding(3, 512)
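        # Worked example (assuming img_size=480): (480 // 4) // 8 = 15, so
        # num_img_tokens = 20 + 15 ** 2 = 245 (20 text positions plus a 15x15
        # grid of stride-32 visual tokens); the position table above reserves
        # one extra slot (num_img_tokens + 1).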

        self.mlm_mask_embed = nn.Embedding(1, 512)
        self.mlm_pos_mlp = nn.Sequential(
            nn.Linear(2, 512),
            nn.LayerNorm(512),
            nn.Linear(512,512),
            nn.GELU()
        )

    def _get_binary_mask(self, target):
        # return a per-class binary mask (unused in this evaluation script)
        y, x = target.size()
        target_onehot = torch.zeros(self.num_classes + 1, y, x)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg
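    # Shape sketch for semantic_inference, inferred from the einsum (assuming
    # B images, Q queries, and C+1 class logits with index 0 dropped):
    #   mask_cls:  (B, Q, C+1) -> softmax, slice [..., 1:] -> (B, Q, C)
    #   mask_pred: (B, Q, H, W) -> sigmoid
    #   semseg:    (B, C, H, W), semseg[b, c] = sum_q mask_cls[b, q, c] * mask_pred[b, q]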

    def forward(self, image, sentences, attentions): 
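        # Interleaved vision-language encoding (a reading of the calls below,
        # not documented upstream): each image-encoder stage i is followed by
        # forward_pwam{i}, which fuses visual and linguistic features; the four
        # fused residuals become the multi-scale inputs s1..s4 of the classifier.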
        input_shape = image.shape[-2:]
        l_mask = attentions.unsqueeze(dim=-1)

        i0, Wh, Ww = self.image_model.forward_stem(image)
        l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)

        i1 = self.image_model.forward_stage1(i0, Wh, Ww)
        l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
        i1_residual, H, W, i1_temp, Wh, Ww  = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
        l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) 
        i1 = i1_temp

        i2 = self.image_model.forward_stage2(i1, Wh, Ww)
        l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
        i2_residual, H, W, i2_temp, Wh, Ww  = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
        l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) 
        i2 = i2_temp

        i3 = self.image_model.forward_stage3(i2, Wh, Ww)
        l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
        i3_residual, H, W, i3_temp, Wh, Ww  = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
        l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) 
        i3 = i3_temp

        i4 = self.image_model.forward_stage4(i3, Wh, Ww)
        l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
        i4_residual, H, W, i4_temp, Wh, Ww  = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
        l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) 
        i4 = i4_temp

        outputs = {
            's1': i1_residual,
            's2': i2_residual,
            's3': i3_residual,
            's4': i4_residual,
        }

        predictions, _ = self.classifier(outputs)
        return predictions

def main(args):
    device = 'cuda'
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
                                  sampler=test_sampler, num_workers=args.workers)
    print(args.model)
    single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='', args=args)
    checkpoint = torch.load(args.resume, map_location='cpu')

    if args.model != 'lavt_one':
        model_class = MultiModalBert
        single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim)
        # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
        if args.ddp_trained_weights:
            single_bert_model.pooler = None
    else:
        single_bert_model = None

    # per-stage feature channels and strides fed to the MaskFormer head
    input_shape = dict()
    input_shape['s1'] = Dict({'channel': 128,  'stride': 4})
    input_shape['s2'] = Dict({'channel': 256,  'stride': 8})
    input_shape['s3'] = Dict({'channel': 512,  'stride': 16})
    input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})

    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]

    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
    cfg.MODEL.MASK_FORMER.PRE_NORM = False
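    # The keys above follow the detectron2/Mask2Former config schema consumed
    # by MaskFormerHead. NUM_OBJECT_QUERIES = 1 and NUM_CLASSES = 1 match the
    # one-object-per-expression setting of referring segmentation (an
    # assumption, consistent with evaluate() thresholding a single-channel mask).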


    maskformer_head = MaskFormerHead(cfg, input_shape)

    model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    evaluate(model, data_loader_test, device=device)


if __name__ == "__main__":
    from args import get_parser
    parser = get_parser()
    args = parser.parse_args()
    print('Image size: {}'.format(str(args.img_size)))
    print(args)
    main(args)