import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn

from functools import reduce
import operator
from bert.multimodal_bert import MultiModalBert

from lib import multimodal_segmentation_ppm

import transforms as T
import utils
import numpy as np

import torch.nn.functional as F


import torch.backends.cudnn as cudnn

from modeling.MaskFormerModel import MaskFormerHead
from addict import Dict

from mask2former_utils.criterion import SetCriterion
from mask2former_utils.matcher import HungarianMatcher
from bert.modeling_bert import BertLMPredictionHead, BertEncoder




class WrapperModel(nn.Module):
    def __init__(self, image_model, language_model, classifier, args):
        super().__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier

        self.lang_proj = nn.Linear(768,256)

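        # Config for a lightweight BERT encoder (8 layers, 8 heads, hidden 512)
        # used as the multimodal MLM transformer over the concatenated
        # [language | visual | position] token sequence.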
        config = Dict({
            "architectures": ["BertForMaskedLM"],
            "attention_probs_dropout_prob": 0.1,
            "gradient_checkpointing": False,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            #"max_position_embeddings": 16+20,
            "model_type": "bert",
            "num_attention_heads": 8,
            "num_hidden_layers": 8,
            "pad_token_id": 0,
            "position_embedding_type": "absolute",
            "transformers_version": "4.6.0.dev0",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 30522,
        })
        self.mlm_transformer = BertEncoder(config)

        self.mlm_vis_proj = nn.Conv2d(1024, 512, 1)
        self.mlm_lang_proj = nn.Linear(768, 512)
        self.mlm_head = BertLMPredictionHead(config)

        assert args.img_size % 4 == 0
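        # MLM sequence length: 20 language tokens plus one visual token per
        # stride-32 feature location ((img_size / 4) / 8 per side); the extra
        # +1 slot in mlm_pos_embeds is for the position-MLP token.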
        num_img_tokens = 20 + ((args.img_size // 4) // 8) ** 2
        self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
        self.mlm_modal_embeds = nn.Embedding(3, 512)

        self.mlm_mask_embed = nn.Embedding(1, 512)
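        # NOTE: mlm_mask_embed appears unused in forward(); it is kept so the
        # optimizer parameter groups and existing checkpoints still match.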
        self.mlm_pos_mlp = nn.Sequential(
            nn.Linear(2, 512),
            nn.LayerNorm(512),
            nn.Linear(512,512),
            nn.GELU()
        )

    def _get_binary_mask(self, target):
        # Return one binary mask per class (background channel dropped).
        # NOTE: relies on `self.num_classes`, which is not set in __init__;
        # this helper is unused in the current forward path.
        y, x = target.size()
        target_onehot = torch.zeros(self.num_classes + 1, y, x)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        # Combine per-query classification scores (background channel dropped)
        # with per-query sigmoid masks via an einsum over the query dimension.
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg

    def forward(self, image, sentences, attentions, mlm_targets, mlm_masks, position): 
        input_shape = image.shape[-2:]
        l_mask = attentions.unsqueeze(dim=-1)

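        # Jointly encode image and text: after each Swin stage, PWAM fusion
        # mixes the two modalities and yields per-stage residual features for
        # the segmentation head. Note the language stem is fed mlm_targets
        # (the masked token ids), so this pass encodes the corrupted sentence.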
        i0, Wh, Ww = self.image_model.forward_stem(image)
        l0, extended_attention_mask = self.language_model.forward_stem(mlm_targets.squeeze(1), attentions)

        i1 = self.image_model.forward_stage1(i0, Wh, Ww)
        l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
        i1_residual, H, W, i1_temp, Wh, Ww  = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
        l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) 
        i1 = i1_temp

        i2 = self.image_model.forward_stage2(i1, Wh, Ww)
        l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
        i2_residual, H, W, i2_temp, Wh, Ww  = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
        l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) 
        i2 = i2_temp

        i3 = self.image_model.forward_stage3(i2, Wh, Ww)
        l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
        i3_residual, H, W, i3_temp, Wh, Ww  = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
        l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) 
        i3 = i3_temp

        i4 = self.image_model.forward_stage4(i3, Wh, Ww)
        l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
        i4_residual, H, W, i4_temp, Wh, Ww  = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
        l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) 
        i4 = i4_temp

        outputs = {
            's1': i1_residual,
            's2': i2_residual,
            's3': i3_residual,
            's4': i4_residual,
        }

        predictions, mask_features = self.classifier(outputs)


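        # Second language pass over the clean (unmasked) sentence; its pooled
        # output becomes the language anchor for the contrastive losses.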
        l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
        l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
        l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
        l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
        l4 = self.language_model.forward_stage4(l3, extended_attention_mask)


        mlp_embed = self.mlm_pos_mlp(position)

        # Positions outside the MLM mask become -1, the ignore_index used by
        # the MLM cross-entropy loss in train_one_epoch.
        mlm_targets = torch.where(mlm_masks > 0, mlm_targets, torch.full_like(mlm_targets, -1))

        vis_features = self.mlm_vis_proj(i4_residual).flatten(2).permute(0, 2, 1)
        lang_features = self.mlm_lang_proj(l4)
        
        mm_features = torch.cat([lang_features, vis_features, mlp_embed.unsqueeze(1)], dim=1)

        modal_embeds = torch.cat([
            self.mlm_modal_embeds.weight[0].unsqueeze(0).repeat(1, lang_features.shape[1], 1),
            self.mlm_modal_embeds.weight[1].unsqueeze(0).repeat(1, vis_features.shape[1], 1),
            self.mlm_modal_embeds.weight[2].unsqueeze(0).repeat(1, 1, 1),
        ], dim=1)

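        # Additive attention mask over the mixed sequence: language positions
        # follow `attentions`; all visual positions and the position token are
        # always visible. Masked slots receive -10000 before softmax, and
        # head_mask supplies one empty entry per transformer layer.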
        mixed_attention_mask = torch.cat(
            [attentions.unsqueeze(-1),
             torch.ones(attentions.shape[0], vis_features.shape[1] + 1, 1).to(attentions.device)],
            dim=1)
        mixed_attention_mask = mixed_attention_mask.permute(0, 2, 1).unsqueeze(1)
        mixed_attention_mask = (1 - mixed_attention_mask) * -10000.0
        head_mask = [None] * 8
        head_features = self.mlm_transformer(
            mm_features + self.mlm_pos_embeds.weight.unsqueeze(0) + modal_embeds,
            mixed_attention_mask, head_mask)[0]
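        # Keep only the first 20 positions (the language tokens) and select
        # non-padding tokens; these are the slots the MLM head predicts.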
        head_features = head_features[:, :20][attentions.bool()]
        
        mlm_predictions = self.mlm_head(head_features)
        mlm_predictions = mlm_predictions.reshape(-1, self.language_model.config.vocab_size)
        mlm_targets = mlm_targets.squeeze(1)[attentions.bool()]

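        # Returns: segmentation predictions, per-pixel mask features, the
        # mask-weighted mean language feature projected to 256-d, and the MLM
        # logits/targets for the auxiliary loss.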
        return predictions, mask_features, self.lang_proj((l4_residual * l_mask).sum(1)/l_mask.sum(1)), mlm_predictions, mlm_targets


# IoU calculation for validation
def IoU(pred, gt):
    pred = (pred > 0.5)

    intersection = torch.sum(torch.mul(pred, gt))
    union = torch.sum(torch.add(pred, gt)) - intersection

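    # Treat an empty intersection or union as IoU 0 to avoid division by zero.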
    if intersection == 0 or union == 0:
        iou = 0
    else:
        iou = float(intersection) / float(union)

    return iou, intersection, union

def get_dataset(image_set, transform, args):
    from data.dataset_refer_bert_mlm import ReferDataset
    ds = ReferDataset(args,
                      split=image_set,
                      image_transforms=transform,
                      target_transforms=None
                      )
    num_classes = 2

    return ds, num_classes



def get_transform(args):
    transforms = [T.Resize(args.img_size, args.img_size),
                  T.ToTensor(),
                  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                  ]

    return T.Compose(transforms)




def evaluate(model, data_loader):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    total_its = 0
    acc_ious = 0

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            total_its += 1
            image, target, sentences, attentions, mlm_targets, mlm_masks, position = (
                d.cuda(non_blocking=True) for d in data)

            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)
            #print("sentences", sentences.shape)
            #print("attentions", attentions.shape)


            output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position)
            mask_cls_results = output["pred_logits"]
            mask_pred_results = output["pred_masks"]

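            # Upsample predicted masks to ground-truth resolution, then fuse
            # per-query class scores and masks into a per-pixel prediction.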
            target_shape = target.shape[-2:]
            mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)

            pred_masks = model.module.semantic_inference(mask_cls_results, mask_pred_results)
            output = pred_masks[0]

            iou, I, U = IoU(output, target)
            acc_ious += iou
            mean_IoU.append(iou)
            cum_I += I
            cum_U += U
            for n_eval_iou in range(len(eval_seg_iou_list)):
                eval_seg_iou = eval_seg_iou_list[n_eval_iou]
                seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
            seg_total += 1
        iou = acc_ious / total_its

    mean_IoU = np.array(mean_IoU)
    mIoU = np.mean(mean_IoU)
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for n_eval_iou in range(len(eval_seg_iou_list)):
        results_str += '    precision@%s = %.2f\n' % \
                       (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
    results_str += '    overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
    print(results_str)

    return 100 * iou, 100 * cum_I / cum_U


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, args):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    train_loss = 0
    total_its = 0

    for data in metric_logger.log_every(data_loader, print_freq, header):
        total_its += 1
        image, target, sentences, attentions, mlm_targets, mlm_masks, position = (
            d.cuda(non_blocking=True) for d in data)

        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)

        output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position)
        avg_lang_feature = torch.nn.functional.normalize(avg_lang_feature, dim=1)

        target_shape = target.shape[-2:]
        output['pred_masks'] = F.interpolate(output['pred_masks'], size=target_shape, mode='bilinear', align_corners=True)

        if "aux_outputs" in output:
            for i, aux_outputs in enumerate(output["aux_outputs"]):
                output['aux_outputs'][i]['pred_masks'] = F.interpolate(output['aux_outputs'][i]['pred_masks'], size=target_shape, mode='bilinear', align_corners=True)

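        # Pixel-level contrastive (PLIC) losses, following the plic_* args:
        # for each sample containing both classes, pixel features are pulled
        # toward their region prototype (and the pooled language feature) and
        # pushed away from the opposite region, with a focal-style weight
        # (1 - p_target) ** alpha on each term.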
        B, C, H, W = mask_features.shape

        target_reshape = F.interpolate(target.unsqueeze(1).float(), size=mask_features.shape[-2:], mode='nearest').long()

        target_reshape = target_reshape.repeat(1, mask_features.shape[1], 1, 1)
        # Initialize as tensors so `.item()` works even when no sample in the
        # batch has both foreground and background pixels.
        plic_lang_loss = mask_features.new_zeros(())
        plic_pos_loss = mask_features.new_zeros(())
        plic_neg_loss = mask_features.new_zeros(())
        for i in range(B):
            if (target_reshape[[i]] == 0).any() and (target_reshape[[i]] == 1).any():

                avg_pos_feature = (mask_features[[i]] * target_reshape[[i]]).sum(-1).sum(-1) / target_reshape[[i]].sum(-1).sum(-1)
                avg_neg_feature = (mask_features[[i]] * (1.0-target_reshape[[i]])).sum(-1).sum(-1) / (1.0-target_reshape[[i]]).sum(-1).sum(-1)
                avg_pos_feature = torch.nn.functional.normalize(avg_pos_feature, dim=1)
                avg_neg_feature = torch.nn.functional.normalize(avg_neg_feature, dim=1)

                pos_features = mask_features[[i]][target_reshape[[i]]==1].view(1, C, -1)
                neg_features = mask_features[[i]][target_reshape[[i]]==0].view(1, C, -1)

                pos_features = torch.nn.functional.normalize(pos_features, dim=1)
                neg_features = torch.nn.functional.normalize(neg_features, dim=1)

                lang_pos_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], pos_features)
                lang_neg_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], neg_features)

                lang_matrix = torch.cat([lang_pos_scores.unsqueeze(-1), lang_neg_scores.unsqueeze(1).repeat(1, lang_pos_scores.shape[1], 1)], dim=2)
                lang_labels = torch.zeros(lang_matrix.shape[1], dtype=torch.long).cuda()
                lang_labels = lang_labels.unsqueeze(0).repeat(lang_matrix.shape[0], 1)

                lang_score = torch.softmax(lang_matrix, -1)
                lang_score = 1.0 - lang_score[:, :, 0]

                pos_pos_scores = torch.einsum("bq,bqn->bn", avg_pos_feature, pos_features)
                pos_neg_scores = torch.einsum("bqn,bqm->bnm", pos_features, neg_features)

                pos_matrix = torch.cat([pos_pos_scores.unsqueeze(-1), pos_neg_scores], dim=2)
                pos_labels = torch.zeros(pos_matrix.shape[1], dtype=torch.long).cuda()
                pos_labels = pos_labels.unsqueeze(0).repeat(pos_matrix.shape[0], 1)

                pos_score = torch.softmax(pos_matrix, -1)
                pos_score = 1.0 - pos_score[:, :, 0]
                neg_neg_scores = torch.einsum("bq,bqn->bn", avg_neg_feature, neg_features)
                neg_pos_scores = torch.einsum("bqn,bqm->bnm", neg_features, pos_features)

                neg_matrix = torch.cat([neg_neg_scores.unsqueeze(-1), neg_pos_scores], dim=2)
                neg_labels = torch.zeros(neg_matrix.shape[1], dtype=torch.long).cuda()
                neg_labels = neg_labels.unsqueeze(0).repeat(neg_matrix.shape[0], 1)

                neg_score = torch.softmax(neg_matrix, -1)
                neg_score = 1.0 - neg_score[:, :, 0]

                pos_loss = (torch.pow(pos_score, args.plic_pos_alpha) *
                            F.cross_entropy(pos_matrix.view(-1, pos_matrix.shape[-1]) / args.plic_pos_temp,
                                            pos_labels.view(-1), reduction='none')).mean()
                neg_loss = (torch.pow(neg_score, args.plic_neg_alpha) *
                            F.cross_entropy(neg_matrix.view(-1, neg_matrix.shape[-1]) / args.plic_neg_temp,
                                            neg_labels.view(-1), reduction='none')).mean()
                lang_loss = (torch.pow(lang_score, args.plic_lang_alpha) *
                             F.cross_entropy(lang_matrix.view(-1, lang_matrix.shape[-1]) / args.plic_lang_temp,
                                             lang_labels.view(-1), reduction='none')).mean()

                plic_pos_loss += pos_loss
                plic_neg_loss += neg_loss
                plic_lang_loss += lang_loss
        plic_pos_loss = (args.plic_pos_weight * plic_pos_loss) / B
        plic_neg_loss = (args.plic_neg_weight * plic_neg_loss) / B
        plic_lang_loss = (args.plic_lang_weight * plic_lang_loss) / B
        plic_loss = plic_pos_loss + plic_neg_loss + plic_lang_loss


        losses = criterion(output, target)
        weight_dict = criterion.weight_dict

        loss_ce = 0.0
        loss_dice = 0.0
        loss_mask = 0.0
        for k in list(losses.keys()):
            if k in weight_dict:
                losses[k] *= weight_dict[k]
                if '_ce' in k:
                    loss_ce += losses[k]
                elif '_dice' in k:
                    loss_dice += losses[k]
                else:
                    loss_mask += losses[k]
            else:
                # remove this loss if not specified in `weight_dict`
                losses.pop(k)
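        # Auxiliary masked-language-modeling loss; ignore_index=-1 skips the
        # positions outside the MLM mask (set to -1 in WrapperModel.forward).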
        smlm_loss = args.smlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets)
        loss = loss_ce + loss_dice + loss_mask + plic_loss + smlm_loss


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        torch.cuda.synchronize()
        train_loss += loss.item()
        iterations += 1
        #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), plic_loss=plic_loss.item(), plic_lang_loss=plic_lang_loss.item(), plic_pos_loss=plic_pos_loss.item(), plic_neg_loss=plic_neg_loss.item(), smlm_loss=smlm_loss.item())
        #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), cl_loss=cl_loss.item(), cl_lang_loss=cl_lang_loss_print, cl_pos_loss=cl_pos_loss_print, cl_neg_loss=cl_neg_loss_print)

        #del image, target, sentences, attentions, loss, output, data
        #if bert_model is not None:
        #    del last_hidden_states, embedding

        #gc.collect()
        #torch.cuda.empty_cache()
        #del loss
        #del cl_loss
        #del cl_lang_loss
        #del loss_ce
        #del loss_dice
        #del loss_mask
        torch.cuda.synchronize()


def main(args):

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    print('seed', seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    #cudnn.benchmark = True

    dataset, num_classes = get_dataset("train",
                                       get_transform(args=args),
                                       args=args)
    dataset_test, _ = get_dataset("val",
                                  get_transform(args=args),
                                  args=args)

    # batch sampler
    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
                                                                    shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # data loader
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)

    # model initialization
    print(args.model)
    model = multimodal_segmentation_ppm.__dict__[args.model](pretrained=args.pretrained_swin_weights,
                                              args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.model != 'lavt_one':
        bert_model = MultiModalBert.from_pretrained(args.ck_bert, embed_dim=model.backbone.embed_dim)
        bert_model.pooler = None  # work-around for a DistributedDataParallel bug in Transformers v3.0.2
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
    else:
        bert_model = None

    input_shape = dict()
    input_shape['s1'] = Dict({'channel': 128,  'stride': 4})
    input_shape['s2'] = Dict({'channel': 256,  'stride': 8})
    input_shape['s3'] = Dict({'channel': 512,  'stride': 16})
    input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})



    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = args.transformer_enc_layers
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]

    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = args.num_object_queries
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = args.dim_feedforward
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = args.dec_layers
    cfg.MODEL.MASK_FORMER.PRE_NORM = False

    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = args.no_object_weight
    cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = args.class_weight
    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = args.dice_weight
    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = args.mask_weight

    cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = args.train_num_points
    cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
    cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
    print(cfg)

    maskformer_head = MaskFormerHead(cfg, input_shape)
    maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)

    model = WrapperModel(model.backbone, bert_model, maskformer_head, args)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    single_model = model.module

    # mask2former loss
    deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
    no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT

    # loss weights
    class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
    dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
    mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT

    # building criterion

    matcher = HungarianMatcher(
        cost_class=class_weight,
        cost_mask=mask_weight,
        cost_dice=dice_weight,
        num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
    )

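    # Base loss weights; with deep supervision they are replicated for each
    # intermediate decoder layer (keys like "loss_ce_0", "loss_ce_1", ...).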
    weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
    if deep_supervision:
        dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
        aux_weight_dict = {}
        for i in range(dec_layers - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    losses = ["labels", "masks"]
    criterion = SetCriterion(
        cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=no_object_weight,
        losses=losses,
        num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
        oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
        importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
        device='cuda'
    )
    
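    # resume == "auto": pick the most recent epoch checkpoint in output_dir.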
    if args.resume == "auto":
        last_ckpt = ""
        for e in range(args.epochs):
            ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
            if os.path.exists(ckpt_path):
                last_ckpt = ckpt_path
        args.resume = last_ckpt

    # resume training
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])

    # parameters to optimize
    backbone_no_decay = list()
    backbone_decay = list()
    for name, m in single_model.image_model.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(m)
        else:
            backbone_decay.append(m)

    params_to_optimize = [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        # the following are the parameters of bert
        {"params": reduce(operator.concat,
                          [[p for p in single_model.language_model.encoder.layer[i].parameters()
                            if p.requires_grad] for i in range(10)])},
        {"params": single_model.language_model.pwams.parameters()},
        {"params": single_model.language_model.res_gates.parameters()},
        {"params": single_model.language_model.norms.parameters()},
        {"params": single_model.lang_proj.parameters()},
        #{"params": single_model.language_model.parameters()},
        {'params': single_model.mlm_head.parameters()},
        {'params': single_model.mlm_vis_proj.parameters()},
        {'params': single_model.mlm_lang_proj.parameters()},
        {'params': single_model.mlm_transformer.parameters()},
        {'params': single_model.mlm_pos_embeds.parameters()},
        {'params': single_model.mlm_modal_embeds.parameters()},
        {'params': single_model.mlm_mask_embed.parameters()},
        {'params': single_model.mlm_pos_mlp.parameters()},
    ]


    # optimizer
    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad
                                  )

    # learning rate scheduler: polynomial decay (1 - iter / total_iters) ** 0.9,
    # stepped once per iteration in train_one_epoch
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    # housekeeping
    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    # resume training (optimizer, lr scheduler, and the epoch)
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = -999

    # training loops
    for epoch in range(max(0, resume_epoch+1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
                        iterations, args)
        iou, overallIoU = evaluate(model, data_loader_test)

        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))


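        # Save to a temp file and rename on the main process only, so readers
        # never observe a partially written checkpoint.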
        dict_to_save = {'model': single_model.state_dict(),
                        'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                        'lr_scheduler': lr_scheduler.state_dict()}

        checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch))
        utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP')
        if utils.is_main_process():
            os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))

        if utils.is_main_process():
            # keep only the most recent `args.max_ckpt` epoch checkpoints
            ckpt_paths = []
            for e in range(args.epochs):
                ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
                if os.path.exists(ckpt_path):
                    ckpt_paths.append(ckpt_path)
            for ckpt_path in ckpt_paths[:-args.max_ckpt]:
                os.remove(ckpt_path)
                print("removed {:s}".format(ckpt_path))


        save_checkpoint = (best_oIoU < overallIoU)
        if save_checkpoint:
            print('Better epoch: {}\n'.format(epoch))
            dict_to_save = {'model': single_model.state_dict(),
                            'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                            'lr_scheduler': lr_scheduler.state_dict()}

            checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))
            utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP')
            if utils.is_main_process():
                os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
            best_oIoU = overallIoU
        torch.cuda.empty_cache()

    # summarize
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    from args import get_parser
    parser = get_parser()
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # set up distributed learning
    utils.init_distributed_mode(args)
    print('Image size: {}'.format(str(args.img_size)))
    main(args)