diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..bd764a800a9d80da448fe912b9e8263364fdc229
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,131 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/\
+
+flagged/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b076d86084a9743afbd07dac765b7fdabb8e064f
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2022, Aastha Singh
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index 91e7a1244e05851f5bc1073302d94ebb97f2321d..7fd235c0f0d7f48e345b3fb1c9eae7903a67cdd3 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,94 @@
----
-title: GLIP BLIP Object Detection VQA
-emoji: 📊
-colorFrom: indigo
-colorTo: pink
-sdk: gradio
-sdk_version: 3.4.1
-app_file: app.py
-pinned: false
-license: bsd-3-clause
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Vision-Language Object Detection and Visual Question Answering
+This repository includes Microsoft's GLIP and Salesforce's BLIP ensembled demo for detecting objects and Visual Question Answering based on text prompts.  
+
+<br />
+
+## About GLIP: Grounded Language-Image Pre-training - 
+> GLIP demonstrate strong zero-shot and few-shot transferability to various object-level recognition tasks.
+
+> The model used in this repo is GLIP-T, it is originally pre-trained on Conceptual Captions 3M and SBU captions.
+
+<br />
+
+## About BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation - 
+
+> A new model architecture that enables a wider range of downstream tasks than existing methods, and a new dataset bootstrapping method for learning from noisy web data.
+
+<br />
+
+## Installation and Setup
+
+***Enviornment*** - Due to limitations with `maskrcnn_benchmark`, this repo requires Pytorch=1.10 and torchvision.
+
+Use `requirements.txt` to install dependencies
+
+```sh
+pip3 install -r requirements.txt
+```
+Build `maskrcnn_benchmark`
+```
+python setup.py build develop --user
+```
+
+To verify a successful build, check the terminal for message  
+"Finished processing dependencies for maskrcnn-benchmark==0.1"
+
+## Checkpoints
+
+> Download the pre-trained models into the `checkpoints` folder.
+
+<br />
+
+```sh
+mkdir checkpoints
+cd checkpoints
+```
+
+Model | Weight
+-- | --
+**GLIP-T** | [weight](https://drive.google.com/file/d/1nlPL6PHkslarP6RiWJJu6QGKjqHG4tkc/view?usp=sharing)
+**BLIP** | [weight](https://drive.google.com/file/d/1QliNGiAcyCCJLd22eNOxWvMUDzb7GzrO/view?usp=sharing)
+
+<br />files.maxMemoryForLargeFilesMB
+
+## If you have an NVIDIA GPU with 8GB VRAM, run local demo using Gradio interface
+
+```sh
+python3 app.py
+```
+## Future Work
+
+- [x] Frame based Visual Question Answering
+- [ ] Each object based Visual Question Answering
+
+
+## Citations
+
+```txt
+@inproceedings{li2022blip,
+      title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation}, 
+      author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
+      year={2022},
+      booktitle={ICML},
+}
+@inproceedings{li2021grounded,
+      title={Grounded Language-Image Pre-training},
+      author={Liunian Harold Li* and Pengchuan Zhang* and Haotian Zhang* and Jianwei Yang and Chunyuan Li and Yiwu Zhong and Lijuan Wang and Lu Yuan and Lei Zhang and Jenq-Neng Hwang and Kai-Wei Chang and Jianfeng Gao},
+      year={2022},
+      booktitle={CVPR},
+}
+@article{zhang2022glipv2,
+  title={GLIPv2: Unifying Localization and Vision-Language Understanding},
+  author={Zhang, Haotian* and Zhang, Pengchuan* and Hu, Xiaowei and Chen, Yen-Chun and Li, Liunian Harold and Dai, Xiyang and Wang, Lijuan and Yuan, Lu and Hwang, Jenq-Neng and Gao, Jianfeng},
+  journal={arXiv preprint arXiv:2206.05836},
+  year={2022}
+}
+@article{li2022elevater,
+  title={ELEVATER: A Benchmark and Toolkit for Evaluating Language-Augmented Visual Models},
+  author={Li*, Chunyuan and Liu*, Haotian and Li, Liunian Harold and Zhang, Pengchuan and Aneja, Jyoti and Yang, Jianwei and Jin, Ping and Lee, Yong Jae and Hu, Houdong and Liu, Zicheng and others},
+  journal={arXiv preprint arXiv:2204.08790},
+  year={2022}
+}
+```
+## Acknowledgement
+The implementation of this work relies on resources from <a href="https://github.com/salesforce/BLIP">BLIP</a>, <a href="https://github.com/microsoft/GLIP">GLIP</a>,  <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6a03c5bcd0e21ebcacac2453c128052f9deac0
--- /dev/null
+++ b/app.py
@@ -0,0 +1,57 @@
+import os
+import gradio as gr
+import warnings
+
+warnings.filterwarnings("ignore")
+
+os.system("python setup.py build develop --user")
+
+from maskrcnn_benchmark.config import cfg
+from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo
+import vqa
+import vqa
+
+# Use this command for evaluate the GLIP-T model
+config_file = "configs/glip_Swin_T_O365_GoldG.yaml"
+weight_file = "checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth"
+
+# manual override some options
+cfg.local_rank = 0
+cfg.num_gpus = 1
+cfg.merge_from_file(config_file)
+cfg.merge_from_list(["MODEL.WEIGHT", weight_file])
+cfg.merge_from_list(["MODEL.DEVICE", "cuda"])
+
+glip_demo = GLIPDemo(
+    cfg,
+    min_image_size=800,
+    confidence_threshold=0.7,
+    show_mask_heatmaps=False
+)
+blip_demo = vqa.VQA(
+    model_path = 'checkpoints/model_base_vqa_capfilt_large.pth'
+)
+
+def predict(image, object, question):
+    result, _ = glip_demo.run_on_web_image(image[:, :, [2, 1, 0]], object, 0.5)
+    answer = blip_demo.vqa_demo(image, question)
+    return result[:, :, [2, 1, 0]], answer
+
+image = gr.inputs.Image()
+
+gr.Interface(
+    description="GLIP + BLIP VQA Demo.",
+    fn=predict,
+    inputs=[
+        "image", 
+        gr.Textbox(label='Objects', lines=1, placeholder="Objects here.."), 
+        gr.Textbox(label='Question', lines=1, placeholder="Question here..")],
+
+    outputs=[
+        gr.outputs.Image(
+            type="pil",
+            label="grounding results"
+        ),
+        gr.Textbox(label="Answer")
+    ],
+).launch()
\ No newline at end of file
diff --git a/checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth b/checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d05b8d5d3318107871c13ca068ee094644600779
--- /dev/null
+++ b/checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bec0a3dea804fcb278d7106c5438de5116ee888e49dfae46270e7ad7bc4ccbf
+size 3710104213
diff --git a/checkpoints/model_base_vqa_capfilt_large.pth b/checkpoints/model_base_vqa_capfilt_large.pth
new file mode 100644
index 0000000000000000000000000000000000000000..df8c62ad684ab84409a19a947cd33b920b78b5ad
--- /dev/null
+++ b/checkpoints/model_base_vqa_capfilt_large.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a7d546209f1ccfa8b3cd3a0138c53e0d1e95e4a4bc280bef8f67e20fe4925ae
+size 1446244375
diff --git a/configs/glip_Swin_T_O365_GoldG.yaml b/configs/glip_Swin_T_O365_GoldG.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..80b9edba1b47a83f5da99254dd081dac3f80354a
--- /dev/null
+++ b/configs/glip_Swin_T_O365_GoldG.yaml
@@ -0,0 +1,100 @@
+MODEL:
+  META_ARCHITECTURE: "GeneralizedVLRCNN"
+  WEIGHT: "swin_tiny_patch4_window7_224.pth"
+  RPN_ONLY: True
+  RPN_ARCHITECTURE: "VLDYHEAD"
+
+  BACKBONE:
+    CONV_BODY: "SWINT-FPN-RETINANET"
+    OUT_CHANNELS: 256
+    FREEZE_CONV_BODY_AT: -1
+
+  LANGUAGE_BACKBONE:
+    FREEZE: False
+    MODEL_TYPE: "bert-base-uncased" # "roberta-base", "clip"
+    MASK_SPECIAL: False
+
+  RPN:
+    USE_FPN: True
+    ANCHOR_SIZES: (64, 128, 256, 512, 1024)
+    ANCHOR_STRIDE: (8, 16, 32, 64, 128)
+    ASPECT_RATIOS: (1.0,)
+    SCALES_PER_OCTAVE: 1
+
+  DYHEAD:
+    CHANNELS: 256
+    NUM_CONVS: 6
+    USE_GN: True
+    USE_DYRELU: True
+    USE_DFCONV: True
+    USE_DYFUSE: True
+    TOPK: 9 # topk for selecting candidate positive samples from each level
+    SCORE_AGG: "MEAN"
+    LOG_SCALE: 0.0
+
+    FUSE_CONFIG:
+      EARLY_FUSE_ON: True
+      TYPE: "MHA-B"
+      USE_CLASSIFICATION_LOSS: False
+      USE_TOKEN_LOSS: False
+      USE_CONTRASTIVE_ALIGN_LOSS: False
+      CONTRASTIVE_HIDDEN_DIM: 64
+      USE_DOT_PRODUCT_TOKEN_LOSS: True
+      USE_FUSED_FEATURES_DOT_PRODUCT: True
+      USE_LAYER_SCALE: True
+      CLAMP_MIN_FOR_UNDERFLOW: True
+      CLAMP_MAX_FOR_OVERFLOW: True
+      CLAMP_BERTATTN_MIN_FOR_UNDERFLOW: True
+      CLAMP_BERTATTN_MAX_FOR_OVERFLOW: True
+      CLAMP_DOT_PRODUCT: True
+           
+    USE_CHECKPOINT: True
+
+TEST:
+  DURING_TRAINING: False
+  IMS_PER_BATCH: 64
+
+# use for grounding model
+DATASETS:
+  TRAIN: ("object365_dt_train", "mixed_train_no_coco", "flickr30k_train", )
+  TEST: ("coco_2017_val", )
+  DISABLE_SHUFFLE: False
+  ADD_DET_PROMPT: False
+  RANDOM_SAMPLE_NEG: 85
+  CONTROL_PROB: (0.0, 0.0, 0.5, 0.0)
+
+  SEPARATION_TOKENS: ". "
+
+INPUT:
+  PIXEL_MEAN: [ 103.530, 116.280, 123.675 ]
+  PIXEL_STD: [ 57.375, 57.120, 58.395 ]
+  MIN_SIZE_TRAIN: 800
+  MAX_SIZE_TRAIN: 1333
+  MIN_SIZE_TEST: 800
+  MAX_SIZE_TEST: 1333
+
+AUGMENT:
+  MULT_MIN_SIZE_TRAIN: (480,560,640,720,800)
+
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+
+SOLVER:
+  OPTIMIZER: ADAMW
+  BASE_LR: 0.0001
+  LANG_LR: 0.00001
+  WEIGHT_DECAY: 0.0001
+  STEPS: (0.67, 0.89)
+  MAX_EPOCH: 30
+  IMS_PER_BATCH: 64
+  WARMUP_ITERS: 2000
+  WARMUP_FACTOR: 0.001
+  USE_AMP: True
+  MODEL_EMA: 0.999
+  FIND_UNUSED_PARAMETERS: False
+
+  CLIP_GRADIENTS:
+    ENABLED: True
+    CLIP_TYPE: "full_model"
+    CLIP_VALUE: 1.0
+    NORM_TYPE: 2.0
\ No newline at end of file
diff --git a/configs/med_config.json b/configs/med_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0ffad0a6f3c2f9f11b8faa84529d9860bb70327a
--- /dev/null
+++ b/configs/med_config.json
@@ -0,0 +1,21 @@
+{
+  "architectures": [
+    "BertModel"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "type_vocab_size": 2,
+  "vocab_size": 30524,
+  "encoder_width": 768,
+  "add_cross_attention": true   
+}
diff --git a/configs/vqa.yaml b/configs/vqa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..74327e6d0a34672023b44569558fe8beeb052548
--- /dev/null
+++ b/configs/vqa.yaml
@@ -0,0 +1,25 @@
+vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
+vg_root: '/export/share/datasets/vision/visual-genome/'  #followed by image/
+train_files: ['vqa_train','vqa_val','vg_qa']
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
+
+# size of vit model; base or large
+vit: 'base'
+batch_size_train: 16 
+batch_size_test: 32 
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+init_lr: 2e-5
+
+image_size: 480
+
+k_test: 128
+inference: 'rank'
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 10
\ No newline at end of file
diff --git a/itm.py b/itm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da8af6dfe782beff41de4efb952f481fa97a6c6
--- /dev/null
+++ b/itm.py
@@ -0,0 +1,77 @@
+import sys
+from PIL import Image
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from models.blip_vqa import blip_vqa
+from models.blip_itm import blip_itm
+
+
+class VQA:
+    def __init__(self, model_path, image_size=480):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
+        self.model.eval()
+        self.model = self.model.to(self.device)
+
+    def load_demo_image(self, image_size, img_path, device):
+        raw_image = Image.open(img_path).convert('RGB')   
+        w,h = raw_image.size
+        transform = transforms.Compose([
+            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+            ]) 
+        image = transform(raw_image).unsqueeze(0).to(device)   
+        return raw_image, image
+
+    def vqa(self, img_path, question):
+        raw_image, image = self.load_demo_image(image_size=480, img_path=img_path, device=self.device)        
+        with torch.no_grad():
+            answer = self.model(image, question, train=False, inference='generate')
+            return answer[0]
+class ITM:
+    def __init__(self, model_path, image_size=384):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = blip_itm(pretrained=model_path, image_size=image_size, vit='base')
+        self.model.eval()
+        self.model = self.model.to(device='cpu')
+    
+    def load_demo_image(self, image_size, img_path, device):
+        raw_image = Image.open(img_path).convert('RGB')   
+        w,h = raw_image.size
+        transform = transforms.Compose([
+            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+            ]) 
+        image = transform(raw_image).unsqueeze(0).to(device)   
+        return raw_image, image
+
+    def itm(self, img_path, caption):
+        raw_image, image = self.load_demo_image(image_size=384,img_path=img_path, device=self.device)
+        itm_output = self.model(image,caption,match_head='itm')
+        itm_score = torch.nn.functional.softmax(itm_output,dim=1)[:,1]
+        itc_score = self.model(image,caption,match_head='itc')
+        # print('The image and text is matched with a probability of %.4f'%itm_score)
+        # print('The image feature and text feature has a cosine similarity of %.4f'%itc_score)
+        return itm_score, itc_score
+
+if __name__=="__main__":
+    if not len(sys.argv) == 3:
+        print('Format: python3 vqa.py <path_to_img> <question>')
+        print('Sample: python3 vqa.py sample.jpg "What is the color of the horse?"')
+        
+    else:
+        model_path = 'checkpoints/model_base_vqa_capfilt_large.pth'
+        model2_path = 'model_base_retrieval_coco.pth'
+        # vqa_object = VQA(model_path=model_path)
+        itm_object = ITM(model_path=model2_path)
+        img_path = sys.argv[1]
+        # question = sys.argv[2]
+        caption = sys.argv[2]
+        # answer = vqa_object.vqa(img_path, caption)
+        itm_score, itc_score = itm_object.itm(img_path, caption)
+        # print('Question: {} | Answer: {}'.format(caption, answer))
+        print('Caption: {} | The image and text is matched with a probability of %.4f: {} | The image feature and text feature has a cosine similarity of %.4f: {}'.format (caption,itm_score,itc_score))
+
diff --git a/maskrcnn_benchmark/__init__.py b/maskrcnn_benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bc96c7a6bf8379e1adfb3e4adf536107b385fa9
--- /dev/null
+++ b/maskrcnn_benchmark/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/maskrcnn_benchmark/config/__init__.py b/maskrcnn_benchmark/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2015d6bd830bc3e0ec8b1ca7fcb63b4781a41ad
--- /dev/null
+++ b/maskrcnn_benchmark/config/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from .defaults import _C as cfg
+from .paths_catalog import try_to_find
\ No newline at end of file
diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd62a9ea307b727e0db06985264707046e8c7234
--- /dev/null
+++ b/maskrcnn_benchmark/config/defaults.py
@@ -0,0 +1,861 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import os
+
+from yacs.config import CfgNode as CN
+
+# -----------------------------------------------------------------------------
+# Convention about Training / Test specific parameters
+# -----------------------------------------------------------------------------
+# Whenever an argument can be either used for training or for testing, the
+# corresponding name will be post-fixed by a _TRAIN for a training parameter,
+# or _TEST for a test-specific parameter.
+# For example, the number of images during training will be
+# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
+# IMAGES_PER_BATCH_TEST
+
+# -----------------------------------------------------------------------------
+# Config definition
+# -----------------------------------------------------------------------------
+
+_C = CN()
+
+_C.MODEL = CN()
+_C.MODEL.RPN_ONLY = False
+_C.MODEL.BOX_ON = True
+_C.MODEL.MASK_ON = False
+_C.MODEL.KEYPOINT_ON = False
+_C.MODEL.DEVICE = "cuda"
+
+_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
+
+_C.MODEL.RPN_ARCHITECTURE = "RPN"
+_C.MODEL.DEBUG = False  # add debug flag
+_C.MODEL.ONNX = False  # add onnx flag
+
+# If the WEIGHT starts with a catalog://, like :R-50, the code will look for
+# the path in paths_catalog. Else, it will use it as the specified absolute
+# path
+_C.MODEL.WEIGHT = ""
+_C.MODEL.PRETRAIN_NAME = ""
+
+# If LINEAR_PROB = True, only the last linear layers in rpn and roi_head are trainable
+_C.MODEL.LINEAR_PROB = False
+
+# -----------------------------------------------------------------------------
+# Multitask Training / Test specific parameters
+# -----------------------------------------------------------------------------
+_C.MODEL.MULTITASK = CN(new_allowed=True)
+
+# -----------------------------------------------------------------------------
+# INPUT
+# -----------------------------------------------------------------------------
+_C.INPUT = CN()
+# Size of the smallest side of the image during training
+_C.INPUT.MIN_SIZE_TRAIN = 800  # (800,)
+# Maximum size of the side of the image during training
+_C.INPUT.MAX_SIZE_TRAIN = 1333
+# Size of the smallest side of the image during testing
+_C.INPUT.MIN_SIZE_TEST = 800
+# Maximum size of the side of the image during testing
+_C.INPUT.MAX_SIZE_TEST = 1333
+# Values to be used for image normalization
+_C.INPUT.PIXEL_MEAN = [102.9801, 115.9465, 122.7717]
+# Values to be used for image normalization
+_C.INPUT.PIXEL_STD = [1., 1., 1.]
+# Convert image to BGR format (for Caffe2 models), in range 0-255
+_C.INPUT.TO_BGR255 = True
+_C.INPUT.FORMAT = ''
+_C.INPUT.FIX_RES = False
+
+# -----------------------------------------------------------------------------
+# Augmentation
+# -----------------------------------------------------------------------------
+_C.AUGMENT = CN()
+_C.AUGMENT.USE_RA = 0
+_C.AUGMENT.FLIP_PROB_TRAIN = 0.5
+_C.AUGMENT.VERTICAL_FLIP_PROB_TRAIN = 0.0
+_C.AUGMENT.MULT_MIN_SIZE_TRAIN = ()
+
+_C.AUGMENT.BRIGHTNESS = 0.0
+_C.AUGMENT.CONTRAST = 0.0
+_C.AUGMENT.SATURATION = 0.0
+_C.AUGMENT.HUE = 0.0
+
+_C.AUGMENT.CROP_PROB = 0.5
+_C.AUGMENT.CROP_MIN_IOUS = (0.1, 0.3, 0.5, 0.7, 0.9)
+_C.AUGMENT.CROP_MIN_SIZE = 0.3
+
+# -----------------------------------------------------------------------------
+# Dataset
+# -----------------------------------------------------------------------------
+_C.DATASETS = CN()
+# List of the dataset names for training, as present in paths_catalog.py
+_C.DATASETS.TRAIN = ()
+# List of the dataset names for testing, as present in paths_catalog.py
+_C.DATASETS.TEST = ()
+# Use is_crowd label
+_C.DATASETS.USE_CROWD = False
+_C.DATASETS.CLASS_AGNOSTIC = False
+_C.DATASETS.CLASS_CONCAT = False
+_C.DATASETS.MAX_BOX = -1
+_C.DATASETS.SAMPLE_RATIO = 0.0
+_C.DATASETS.FEW_SHOT = 0
+# SHUFFLE_SEED != 0 means shuffle the dataset in the few shot setting
+_C.DATASETS.SHUFFLE_SEED = 0
+_C.DATASETS.PREDEFINED_TEXT = ''
+_C.DATASETS.ALTERNATIVE_TRAINING = False
+_C.DATASETS.MULTISTAGE_TRAINING = False
+_C.DATASETS.REGISTER = CN(new_allowed=True)
+_C.DATASETS.BOX_THRESHOLD = 0.1
+# Duplicate Dataset
+_C.DATASETS.COCO_COPY = 1
+_C.DATASETS.LVIS_COPY = 1
+_C.DATASETS.FLICKR_COPY = 1
+_C.DATASETS.MIXED_COPY = 1
+_C.DATASETS.OBJECT365_COPY = 1
+_C.DATASETS.VG_COPY = 1
+_C.DATASETS.OI_COPY = 1
+_C.DATASETS.IN_COPY = 1
+
+# Duplicate Dataset
+_C.DATASETS.COCO_COPY = 1
+_C.DATASETS.FLICKR_COPY = 1
+_C.DATASETS.MIXED_COPY = 1
+_C.DATASETS.OBJECT365_COPY = 1
+_C.DATASETS.VG_COPY = 1
+_C.DATASETS.OI_COPY = 1
+_C.DATASETS.IN_COPY = 1
+_C.DATASETS.GENERAL_COPY = -1
+_C.DATASETS.GENERAL_COPY_TEST = -1
+
+# OD to Grounding
+_C.DATASETS.RANDOM_SAMPLE_NEG = -1
+_C.DATASETS.ADD_DET_PROMPT = False
+_C.DATASETS.ADD_DET_PROMPT_ADVANCED = False
+_C.DATASETS.USE_OD_AUG = False
+_C.DATASETS.USE_COCO_FORMAT = False
+_C.DATASETS.CONTROL_PROB = ()
+_C.DATASETS.DISABLE_SHUFFLE = False
+_C.DATASETS.PROMPT_VERSION = ""
+_C.DATASETS.PROMPT_LIMIT_NEG = -1
+_C.DATASETS.POS_QUESTION_PROB = 0.6
+_C.DATASETS.NEG_QUESTION_PROB = 0.8
+_C.DATASETS.FULL_QUESTION_PROB = 0.5
+_C.DATASETS.ONE_HOT = False
+_C.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT = False
+
+_C.DATASETS.DISABLE_CLIP_TO_IMAGE = False
+_C.DATASETS.SEPARATION_TOKENS = " "
+
+# LVIS
+_C.DATASETS.LVIS_USE_NORMAL_AP = False
+_C.DATASETS.SPECIAL_SAFEGUARD_FOR_COCO_GROUNDING = False
+
+# Caption
+_C.DATASETS.BING_INDEX_LIST = []
+_C.DATASETS.CAPTION_MIN_BOX = 1
+_C.DATASETS.REPLACE_CLEAN_LABEL = False
+_C.DATASETS.FURTHER_SCREEN = False
+_C.DATASETS.CAPTION_CONF = 0.9
+_C.DATASETS.CAPTION_NMS = 0.9
+_C.DATASETS.PACK_RANDOM_CAPTION_NUMBER = 0
+_C.DATASETS.INFERENCE_CAPTION = False
+_C.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA = -1.0
+_C.DATASETS.RANDOM_PACK_PROB = -1.0
+_C.DATASETS.NO_RANDOM_PACK_PROBABILITY = 0.0
+_C.DATASETS.SAFEGUARD_POSITIVE_CAPTION = True
+_C.DATASETS.CAPTION_FORMAT_VERSION = "v1"
+_C.DATASETS.LOCAL_DEBUG = False
+
+
+# Od in the wild
+_C.DATASETS.PREDEFINED_TEXT = None
+_C.DATASETS.TRAIN_DATASETNAME_SUFFIX = ""
+_C.DATASETS.TEST_DATASETNAME_SUFFIX = ""
+_C.DATASETS.OVERRIDE_CATEGORY = None
+_C.DATASETS.USE_OVERRIDE_CATEGORY = False
+_C.DATASETS.SUPRESS_QUERY = None
+_C.DATASETS.USE_SUPRESS_QUERY = False
+_C.DATASETS.USE_CAPTION_PROMPT = False
+_C.DATASETS.CAPTION_PROMPT = None
+
+_C.DATASETS.FLICKR_GT_TYPE = "separate"
+
+# VQA
+_C.DATASETS.DIVER_BOX_FOR_VQA = False
+# -----------------------------------------------------------------------------
+# DataLoader
+# -----------------------------------------------------------------------------
+_C.DATALOADER = CN()
+# Number of data loading threads
+_C.DATALOADER.NUM_WORKERS = 4
+# If > 0, this enforces that each collated batch should have a size divisible
+# by SIZE_DIVISIBILITY
+_C.DATALOADER.SIZE_DIVISIBILITY = 0
+# If True, each batch should contain only images for which the aspect ratio
+# is compatible. This groups portrait images together, and landscape images
+# are not batched with portrait images.
+_C.DATALOADER.ASPECT_RATIO_GROUPING = True
+# Define min number of keypoints required from GT, for example 10 out of 17
+_C.DATALOADER.MIN_KPS_PER_IMS = 0
+# Use random sampler during training
+_C.DATALOADER.USE_RANDOM_SEED = False
+
+_C.DATALOADER.DISTRIBUTE_CHUNK_AMONG_NODE = False
+# ---------------------------------------------------------------------------- #
+# Backbone options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.BACKBONE = CN()
+
+# The backbone conv body to use
+# The string must match a function that is imported in modeling.model_builder
+# (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN
+# backbone)
+_C.MODEL.BACKBONE.CONV_BODY = "R-50-C4"
+
+# Add StopGrad at a specified stage so the bottom layers are frozen
+_C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2
+_C.MODEL.BACKBONE.FREEZE = False
+_C.MODEL.BACKBONE.GROUP = 1
+_C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4
+# Option to reset bn running statics
+_C.MODEL.BACKBONE.RESET_BN = False
+# Backbone Normalization Level
+_C.MODEL.BACKBONE.NORM_LEVEL = 3
+# BN for backbone
+_C.MODEL.BACKBONE.USE_BN = False
+# Sync BN for backbone
+_C.MODEL.BACKBONE.USE_SYNCBN = False
+_C.MODEL.BACKBONE.USE_NSYNCBN = False
+# GN for backbone
+_C.MODEL.BACKBONE.USE_GN = False
+# Evo Norm for backbone
+_C.MODEL.BACKBONE.USE_EN = False
+# Layers for backbone
+_C.MODEL.BACKBONE.USE_DFCONV = False
+_C.MODEL.BACKBONE.USE_DYRELU = False
+_C.MODEL.BACKBONE.USE_SE = False
+_C.MODEL.BACKBONE.LAYER_SETUP = (3, 4, 6, 3)
+_C.MODEL.BACKBONE.LAYER_SEARCH = CN(new_allowed=True)
+_C.MODEL.BACKBONE.OUT_FEATURES = ("stage2", "stage3", "stage4", "stage5")
+_C.MODEL.BACKBONE.FPN_LAYER = ()
+_C.MODEL.BACKBONE.USE_CHECKPOINT = False
+# Add JF efficient det cfgs
+_C.MODEL.BACKBONE.EFFICIENT_DET_START_FROM = 3
+_C.MODEL.BACKBONE.EFFICIENT_DET_COMPOUND = 0
+_C.MODEL.BACKBONE.EFFICIENT_DET_BIFPN_VERSION = 0
+
+_C.MODEL.LANGUAGE_BACKBONE = CN()
+_C.MODEL.LANGUAGE_BACKBONE.WEIGHT = ""
+_C.MODEL.LANGUAGE_BACKBONE.FREEZE = False
+_C.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT = False
+_C.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE = "bert-base-uncased"
+_C.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE = "bert-base-uncased"
+_C.MODEL.LANGUAGE_BACKBONE.LANG_DIM = 768
+_C.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN = 256
+_C.MODEL.LANGUAGE_BACKBONE.N_LAYERS = 1
+_C.MODEL.LANGUAGE_BACKBONE.UNUSED_TOKEN = 106
+_C.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL = False
+
+_C.MODEL.LANGUAGE_BACKBONE.RNN_TYPE = "lstm"
+_C.MODEL.LANGUAGE_BACKBONE.VARIABLE_LENGTH = True
+_C.MODEL.LANGUAGE_BACKBONE.WORD_EMBEDDING_SIZE = 512
+_C.MODEL.LANGUAGE_BACKBONE.WORD_VEC_SIZE = 512
+_C.MODEL.LANGUAGE_BACKBONE.HIDDEN_SIZE = 512
+_C.MODEL.LANGUAGE_BACKBONE.BIDIRECTIONAL = True
+_C.MODEL.LANGUAGE_BACKBONE.INPUT_DROPOUT_P = 0.5
+_C.MODEL.LANGUAGE_BACKBONE.DROPOUT_P = 0.2
+_C.MODEL.LANGUAGE_BACKBONE.CORPUS_PATH = ""
+_C.MODEL.LANGUAGE_BACKBONE.VOCAB_SIZE = 0
+
+_C.MODEL.LANGUAGE_BACKBONE.PAD_MAX = True
+# ---------------------------------------------------------------------------- #
+# FPN options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.FPN = CN()
+_C.MODEL.FPN.FREEZE = False
+_C.MODEL.FPN.USE_GN = False
+_C.MODEL.FPN.USE_RELU = False
+_C.MODEL.FPN.USE_DYRELU = False
+_C.MODEL.FPN.DROP_BLOCK = True
+_C.MODEL.FPN.DROP_PROB = 0.3
+_C.MODEL.FPN.DROP_SIZE = 3
+_C.MODEL.FPN.USE_SPP = False
+_C.MODEL.FPN.USE_PAN = False
+_C.MODEL.FPN.USE_DYHEAD = False
+_C.MODEL.FPN.RETURN_SWINT_FEATURE_BEFORE_FUSION = False
+# ---------------------------------------------------------------------------- #
+# BIFPN options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.BIFPN = CN()
+_C.MODEL.BIFPN.NUM_REPEATS = 1
+_C.MODEL.BIFPN.USE_ATTENTION = True
+
+# ---------------------------------------------------------------------------- #
+# Group Norm options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.GROUP_NORM = CN()
+# Number of dimensions per group in GroupNorm (-1 if using NUM_GROUPS)
+_C.MODEL.GROUP_NORM.DIM_PER_GP = -1
+# Number of groups in GroupNorm (-1 if using DIM_PER_GP)
+_C.MODEL.GROUP_NORM.NUM_GROUPS = 16
+# GroupNorm's small constant in the denominator
+_C.MODEL.GROUP_NORM.EPSILON = 1e-5
+
+# ---------------------------------------------------------------------------- #
+# Evo Norm options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.EVO_NORM = CN()
+# Number of groups in EvoNorm (-1 if using DIM_PER_GP)
+_C.MODEL.EVO_NORM.NUM_GROUPS = 8
+# EvoNorm's small constant in the denominator
+_C.MODEL.EVO_NORM.EPSILON = 1e-5
+
+# ---------------------------------------------------------------------------- #
+# RetinaNet Options (Follow the Detectron version)
+# ---------------------------------------------------------------------------- #
+_C.MODEL.RETINANET = CN()
+# This is the number of foreground classes and background.
+_C.MODEL.RETINANET.NUM_CLASSES = 81
+# Convolutions to use in the cls and bbox tower
+# NOTE: this doesn't include the last conv for logits
+_C.MODEL.RETINANET.NUM_CONVS = 4
+# During inference, #locs to select based on cls score before NMS is performed
+# per FPN level
+_C.MODEL.RETINANET.PRE_NMS_TOP_N = 1000
+# Prior prob for the positives at the beginning of training. This is used to set
+# the bias init for the logits layer
+_C.MODEL.RETINANET.PRIOR_PROB = 0.01
+# Inference cls score threshold, anchors with score > INFERENCE_TH are
+# considered for inference
+_C.MODEL.RETINANET.INFERENCE_TH = 0.05
+# NMS threshold used in RetinaNet
+_C.MODEL.RETINANET.NMS_TH = 0.4
+_C.MODEL.RETINANET.DETECTIONS_PER_IMG = 100
+
+# ---------------------------------------------------------------------------- #
+# Focal Loss Options (Follow the Detectron version)
+# ---------------------------------------------------------------------------- #
+_C.MODEL.FOCAL = CN()
+# Weight for bbox_regression loss
+_C.MODEL.FOCAL.BBOX_REG_WEIGHT = 4.0
+# Smooth L1 loss beta for bbox regression
+_C.MODEL.FOCAL.BBOX_REG_BETA = 0.11
+# IoU overlap ratio for labeling an anchor as positive
+# Anchors with >= iou overlap are labeled positive
+_C.MODEL.FOCAL.FG_IOU_THRESHOLD = 0.5
+# IoU overlap ratio for labeling an anchor as negative
+# Anchors with < iou overlap are labeled negative
+_C.MODEL.FOCAL.BG_IOU_THRESHOLD = 0.4
+# Focal loss parameter: alpha
+_C.MODEL.FOCAL.LOSS_ALPHA = 0.25
+# Focal loss parameter: gamma
+_C.MODEL.FOCAL.LOSS_GAMMA = 2.0
+
+# ---------------------------------------------------------------------------- #
+# FCOS Options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.FCOS = CN()
+_C.MODEL.FCOS.NUM_CLASSES = 81  # the number of classes including background
+_C.MODEL.FCOS.FPN_STRIDES = [8, 16, 32, 64, 128]
+_C.MODEL.FCOS.PRIOR_PROB = 0.01
+_C.MODEL.FCOS.INFERENCE_TH = 0.05
+_C.MODEL.FCOS.NMS_TH = 0.6
+_C.MODEL.FCOS.PRE_NMS_TOP_N = 1000
+
+# the number of convolutions used in the cls and bbox tower
+_C.MODEL.FCOS.NUM_CONVS = 4
+# if use deformable conv to align features
+_C.MODEL.FCOS.USE_DFCONV = False
+
+# if CENTER_SAMPLING_RADIUS <= 0, it will disable center sampling
+_C.MODEL.FCOS.CENTER_SAMPLING_RADIUS = 0.0
+# IOU_LOSS_TYPE can be "iou", "linear_iou" or "giou"
+_C.MODEL.FCOS.IOU_LOSS_TYPE = "iou"
+
+_C.MODEL.FCOS.NORM_REG_TARGETS = False
+_C.MODEL.FCOS.CENTERNESS_ON_REG = False
+_C.MODEL.FCOS.USE_GT_CENTER = False
+
+_C.MODEL.FCOS.DETECTIONS_PER_IMG = 100
+_C.MODEL.FCOS.USE_GN = False
+_C.MODEL.FCOS.USE_BN = False
+
+_C.MODEL.FCOS.INFERENCE_TH_TRAIN = 0.0
+_C.MODEL.FCOS.PRE_NMS_TOP_N_TRAIN = 3000
+_C.MODEL.FCOS.POST_NMS_TOP_N_TRAIN = 1000
+
+# ---------------------------------------------------------------------------- #
+# ATSS Options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ATSS = CN()
+_C.MODEL.ATSS.NUM_CLASSES = 81  # the number of classes including background
+_C.MODEL.ATSS.PRIOR_PROB = 0.01
+_C.MODEL.ATSS.INFERENCE_TH = 0.05
+_C.MODEL.ATSS.NMS_TH = 0.6
+_C.MODEL.ATSS.PRE_NMS_TOP_N = 1000
+
+# the number of convolutions used in the cls and bbox tower
+_C.MODEL.ATSS.NUM_CONVS = 4
+# the channels of convolutions used in the cls and bbox tower
+_C.MODEL.ATSS.CHANNELS = 128
+# if use deformable conv to align features
+_C.MODEL.ATSS.USE_DFCONV = False
+
+# topk for selecting candidate positive samples from each level
+_C.MODEL.ATSS.TOPK = 9
+
+# Weight for bbox_regression loss
+_C.MODEL.ATSS.REG_LOSS_WEIGHT = 2.0
+
+_C.MODEL.ATSS.DETECTIONS_PER_IMG = 100
+_C.MODEL.ATSS.USE_GN = False
+_C.MODEL.ATSS.USE_BN = False
+
+_C.MODEL.ATSS.USE_DYRELU = False
+_C.MODEL.ATSS.USE_SE = False
+
+_C.MODEL.ATSS.INFERENCE_TH_TRAIN = 0.0
+_C.MODEL.ATSS.PRE_NMS_TOP_N_TRAIN = 3000
+_C.MODEL.ATSS.POST_NMS_TOP_N_TRAIN = 1000
+# ---------------------------------------------------------------------------- #
+# DYHEAD Options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.DYHEAD = CN()
+_C.MODEL.DYHEAD.NUM_CLASSES = 81  # the number of classes including background
+_C.MODEL.DYHEAD.PRIOR_PROB = 0.01
+
+# the number of convolutions used in the cls and bbox tower
+_C.MODEL.DYHEAD.NUM_CONVS = 4
+# the channels of convolutions used in the cls and bbox tower
+_C.MODEL.DYHEAD.CHANNELS = 128
+_C.MODEL.DYHEAD.GROUPS = 1
+# if use deformable conv to align features
+_C.MODEL.DYHEAD.USE_DFCONV = False
+
+# topk for selecting candidate positive samples from each level
+_C.MODEL.DYHEAD.TOPK = 9
+
+_C.MODEL.DYHEAD.SCORE_AGG = "MEAN"  # MEAN or MAX, for binary focal loss score aggregation
+
+_C.MODEL.DYHEAD.LOG_SCALE = 0.0  # temperature (dot product)
+_C.MODEL.DYHEAD.SHALLOW_LOG_SCALE = 0.0  # # temperature (shallow contrastive)
+
+_C.MODEL.DYHEAD.USE_GN = False
+_C.MODEL.DYHEAD.USE_NSYNCBN = False
+_C.MODEL.DYHEAD.USE_SYNCBN = False
+
+_C.MODEL.DYHEAD.USE_DYFUSE = False
+_C.MODEL.DYHEAD.USE_DYRELU = False
+
+_C.MODEL.DYHEAD.CONV_FUNC = ''
+
+# CosineSimOutputLayers: https://github.com/ucbdrive/few-shot-object-detection/blob/master/fsdet/modeling/roi_heads/fast_rcnn.py#L448-L464
+_C.MODEL.DYHEAD.COSINE_SCALE = -1.0
+
+_C.MODEL.DYHEAD.FUSE_CONFIG = CN()
+_C.MODEL.DYHEAD.FUSE_CONFIG.EARLY_FUSE_ON = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.TYPE = ""
+_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE = 256
+_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_OUT_SIZE = 256
+_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT = 0.1
+_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_MLP_LAYERS = 2
+
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_CLASSIFICATION_LOSS = False
+
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_LOSS_WEIGHT = 1.0
+_C.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_GAMMA = 2.0
+_C.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_ALPHA = 0.25
+
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_HIDDEN_DIM = 64
+_C.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_ALIGN_LOSS_WEIGHT = 1.0
+_C.MODEL.DYHEAD.FUSE_CONFIG.DOT_PRODUCT_TOKEN_LOSS_WEIGHT = 1.0
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_LAYER_SCALE = True
+_C.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.STABLE_SOFTMAX_2D = False
+
+_C.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT = False
+
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT = False
+
+# Controls for 
+_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_DOT_PRODUCT = False
+
+# MLM Loss
+_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_FOR_ONLY_POSITIVES = True
+_C.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_OD = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_GOLD = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_COEF = 1.0
+_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_OBJ_FOR_ONLY_POSITIVE  = False
+
+# Shallow Contrastive Loss (FPN)
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS = 100
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS = False
+_C.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_HIDDEN_DIM = 64
+_C.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_LOSS_WEIGHT = 1.0
+
+# Shallow Contrastive Loss (BACKBONE)
+_C.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS = False
+
+_C.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False
+
+# use checkpoint to save memory
+_C.MODEL.DYHEAD.USE_CHECKPOINT = False
+
+# ---------------------------------------------------------------------------- #
+# RPN options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.RPN = CN()
+_C.MODEL.RPN.USE_FPN = False
+# Base RPN anchor sizes given in absolute pixels w.r.t. the scaled network input
+_C.MODEL.RPN.ANCHOR_SIZES = (32, 64, 128, 256, 512)
+# Stride of the feature map that RPN is attached.
+# For FPN, number of strides should match number of scales
+_C.MODEL.RPN.ANCHOR_STRIDE = (16,)
+# RPN anchor aspect ratios
+_C.MODEL.RPN.ASPECT_RATIOS = (0.5, 1.0, 2.0)
+# Anchor shift away ration from the center for r,t,l,d
+_C.MODEL.RPN.ANCHOR_SHIFT = (0.0, 0.0, 0.0, 0.0)
+# Use center to decide anchor size
+_C.MODEL.RPN.USE_RELATIVE_SIZE = False
+# Remove RPN anchors that go outside the image by RPN_STRADDLE_THRESH pixels
+# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
+_C.MODEL.RPN.STRADDLE_THRESH = 0
+# Anchor scales per octave for complex anchors
+_C.MODEL.RPN.OCTAVE = 2.0
+_C.MODEL.RPN.SCALES_PER_OCTAVE = 3
+# Minimum overlap required between an anchor and ground-truth box for the
+# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
+# ==> positive RPN example)
+_C.MODEL.RPN.FG_IOU_THRESHOLD = 0.7
+# Maximum overlap allowed between an anchor and ground-truth box for the
+# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD
+# ==> negative RPN example)
+_C.MODEL.RPN.BG_IOU_THRESHOLD = 0.3
+# Total number of RPN examples per image
+_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
+# Target fraction of foreground (positive) examples per RPN minibatch
+_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
+# Number of top scoring RPN proposals to keep before applying NMS
+# When FPN is used, this is *per FPN level* (not total)
+_C.MODEL.RPN.PRE_NMS_TOP_N_TRAIN = 12000
+_C.MODEL.RPN.PRE_NMS_TOP_N_TEST = 6000
+# Number of top scoring RPN proposals to keep after applying NMS
+_C.MODEL.RPN.POST_NMS_TOP_N_TRAIN = 2000
+_C.MODEL.RPN.POST_NMS_TOP_N_TEST = 1000
+# NMS threshold used on RPN proposals
+_C.MODEL.RPN.NMS_THRESH = 0.7
+# Proposal height and width both need to be greater than RPN_MIN_SIZE
+# (a the scale used during training or inference)
+_C.MODEL.RPN.MIN_SIZE = 0
+# Number of top scoring RPN proposals to keep after combining proposals from
+# all FPN levels
+_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
+_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
+# Custom rpn head, empty to use default conv or separable conv
+_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead"
+_C.MODEL.RPN.FREEZE = False
+_C.MODEL.RPN.FORCE_BOXES = False
+_C.MODEL.RPN.RETURN_FUSED_FEATURES = False
+
+# ---------------------------------------------------------------------------- #
+# ROI HEADS options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ROI_HEADS = CN()
+_C.MODEL.ROI_HEADS.USE_FPN = False
+# Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
+_C.MODEL.ROI_HEADS.FG_IOU_THRESHOLD = 0.5
+# Overlap threshold for an RoI to be considered background
+# (class = 0 if overlap in [0, BG_IOU_THRESHOLD))
+_C.MODEL.ROI_HEADS.BG_IOU_THRESHOLD = 0.5
+# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
+# These are empirically chosen to approximately lead to unit variance targets
+_C.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS = (10., 10., 5., 5.)
+# RoI minibatch size *per image* (number of regions of interest [ROIs])
+# Total number of RoIs per training minibatch =
+#   TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH * NUM_GPUS
+# E.g., a common configuration is: 512 * 2 * 8 = 8192
+_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
+# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
+_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
+
+# Only used on test mode
+
+# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
+# balance obtaining high recall with not having too many low precision
+# detections that will slow down inference post processing steps (like NMS)
+_C.MODEL.ROI_HEADS.SCORE_THRESH = 0.05
+# Overlap threshold used for non-maximum suppression (suppress boxes with
+# IoU >= this threshold)
+_C.MODEL.ROI_HEADS.NMS = 0.5
+# Maximum number of detections to return per image (100 is based on the limit
+# established for the COCO dataset)
+_C.MODEL.ROI_HEADS.DETECTIONS_PER_IMG = 100
+
+_C.MODEL.ROI_BOX_HEAD = CN()
+_C.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor"
+_C.MODEL.ROI_BOX_HEAD.PREDICTOR = "FastRCNNPredictor"
+_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
+_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
+_C.MODEL.ROI_BOX_HEAD.POOLER_SCALES = (1.0 / 16,)
+_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81
+# Hidden layer dimension when using an MLP for the RoI box head
+_C.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM = 1024
+# GN
+_C.MODEL.ROI_BOX_HEAD.USE_GN = False
+# Dilation
+_C.MODEL.ROI_BOX_HEAD.DILATION = 1
+_C.MODEL.ROI_BOX_HEAD.CONV_HEAD_DIM = 256
+_C.MODEL.ROI_BOX_HEAD.NUM_STACKED_CONVS = 4
+# Use D2 style ROIAlignV2
+_C.MODEL.ROI_BOX_HEAD.POOLER_ALIGNED = False
+
+_C.MODEL.ROI_MASK_HEAD = CN()
+_C.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor"
+_C.MODEL.ROI_MASK_HEAD.PREDICTOR = "MaskRCNNC4Predictor"
+_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
+_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
+_C.MODEL.ROI_MASK_HEAD.POOLER_SCALES = (1.0 / 16,)
+_C.MODEL.ROI_MASK_HEAD.MLP_HEAD_DIM = 1024
+_C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256)
+_C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14
+_C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
+# Whether or not resize and translate masks to the input image.
+_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False
+_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5
+# Dilation
+_C.MODEL.ROI_MASK_HEAD.DILATION = 1
+# GN
+_C.MODEL.ROI_MASK_HEAD.USE_GN = False
+# HG
+_C.MODEL.ROI_MASK_HEAD.HG_SCALE = 1
+
+_C.MODEL.ROI_KEYPOINT_HEAD = CN()
+_C.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR = "KeypointRCNNFeatureExtractor"
+_C.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR = "KeypointRCNNPredictor"
+_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
+_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
+_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES = (1.0 / 16,)
+_C.MODEL.ROI_KEYPOINT_HEAD.MLP_HEAD_DIM = 1024
+_C.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS = tuple(512 for _ in range(8))
+_C.MODEL.ROI_KEYPOINT_HEAD.RESOLUTION = 14
+_C.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES = 17
+_C.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME = ()  # If left empty, use default names
+_C.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
+
+# ---------------------------------------------------------------------------- #
+# ResNe[X]t options (ResNets = {ResNet, ResNeXt}
+# Note that parts of a resnet may be used for both the backbone and the head
+# These options apply to both
+# ---------------------------------------------------------------------------- #
+_C.MODEL.RESNETS = CN()
+
+_C.MODEL.RESNETS.USE_STEM3X3 = False
+_C.MODEL.RESNETS.WITH_SE = False
+_C.MODEL.RESNETS.USE_AVG_DOWN = False
+
+# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
+_C.MODEL.RESNETS.NUM_GROUPS = 1
+
+# Baseline width of each group
+_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
+
+# Place the stride 2 conv on the 1x1 filter
+# Use True only for the original MSRA ResNet; use False for C2 and Torch models
+_C.MODEL.RESNETS.STRIDE_IN_1X1 = True
+
+# Residual transformation function
+_C.MODEL.RESNETS.TRANS_FUNC = "BottleneckWithFixedBatchNorm"
+# ResNet's stem function (conv1 and pool1)
+_C.MODEL.RESNETS.STEM_FUNC = "StemWithFixedBatchNorm"
+
+# Apply dilation in stage "res5"
+_C.MODEL.RESNETS.RES5_DILATION = 1
+
+_C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4
+_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
+_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
+
+_C.MODEL.RESNETS.REVISION = "resnet_light"
+# Deformable convolutions
+_C.MODEL.RESNETS.STAGE_WITH_DCN = (False, False, False, False)
+_C.MODEL.RESNETS.WITH_MODULATED_DCN = False
+_C.MODEL.RESNETS.DEFORMABLE_GROUPS = 1
+
+# ---------------------------------------------------------------------------- #
+# Swin Transformer
+# ---------------------------------------------------------------------------- #
+_C.MODEL.SWINT = CN()
+_C.MODEL.SWINT.EMBED_DIM = 96
+_C.MODEL.SWINT.OUT_CHANNELS = (96, 192, 384, 768)
+_C.MODEL.SWINT.DEPTHS = (2, 2, 6, 2)
+_C.MODEL.SWINT.NUM_HEADS = (3, 6, 12, 24)
+_C.MODEL.SWINT.WINDOW_SIZE = 7
+_C.MODEL.SWINT.MLP_RATIO = 4
+_C.MODEL.SWINT.DROP_PATH_RATE = 0.2
+_C.MODEL.SWINT.APE = False
+_C.MODEL.SWINT.VERSION = "v1"
+_C.MODEL.SWINT.OUT_NORM = True
+_C.MODEL.SWINT.LAYER_SCALE = 0
+
+# ---------------------------------------------------------------------------- #
+# CVT SPEC
+# ---------------------------------------------------------------------------- #
+_C.MODEL.SPEC = CN(new_allowed=True)
+
+# ---------------------------------------------------------------------------- #
+# CLIP SPEC
+# ---------------------------------------------------------------------------- #
+_C.MODEL.CLIP = CN()
+_C.MODEL.CLIP.CONTEXT_LENGTH = 256  # default 77
+_C.MODEL.CLIP.WIDTH = 512
+_C.MODEL.CLIP.LAYERS = 12
+_C.MODEL.CLIP.HEADS = 8
+_C.MODEL.CLIP.DROP_PATH = 0.0
+_C.MODEL.CLIP.TOKENIZER = "clip"
+_C.MODEL.CLIP.VOCAB_SIZE = 49408
+
+# ---------------------------------------------------------------------------- #
+# SEARCH
+# ---------------------------------------------------------------------------- #
+
+_C.SEARCH = CN()
+_C.SEARCH.MAX_EPOCH = 20
+_C.SEARCH.SELECT_NUM = 20
+_C.SEARCH.POPULATION_NUM = 64
+_C.SEARCH.MUTATION_NUM = 24
+_C.SEARCH.CROSSOVER_NUM = 24
+_C.SEARCH.MUTATION_PROB = 0.1
+
+# ---------------------------------------------------------------------------- #
+# Solver
+# ---------------------------------------------------------------------------- #
+_C.SOLVER = CN()
+_C.SOLVER.USE_AMP = False
+
+_C.SOLVER.MAX_ITER = 40000
+_C.SOLVER.MULTI_MAX_ITER = ()  # set different max epoch for different stage
+_C.SOLVER.MAX_EPOCH = 0  # any epoch number>0 will overwrite max_iter
+_C.SOLVER.MULTI_MAX_EPOCH = ()  # set different max epoch for different stage
+
+_C.SOLVER.OPTIMIZER = "SGD"  # "ADAMW"
+
+_C.SOLVER.BASE_LR = 0.001
+
+_C.SOLVER.LANG_LR = 0.00001
+_C.SOLVER.BACKBONE_BODY_LR_FACTOR = 1.0
+
+_C.SOLVER.BIAS_LR_FACTOR = 2
+_C.SOLVER.GRAD_CLIP = 0.0
+# D2 gradient clip
+_C.SOLVER.CLIP_GRADIENTS = CN()
+_C.SOLVER.CLIP_GRADIENTS.ENABLED = False
+_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.0
+_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"
+_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
+_C.SOLVER.MODEL_EMA = 0.0
+
+_C.SOLVER.MOMENTUM = 0.9
+
+_C.SOLVER.WEIGHT_DECAY = 0.0005
+_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0
+_C.SOLVER.WEIGHT_DECAY_NORM_FACTOR = 1.0
+
+# use cosine lr to replace default multistage
+_C.SOLVER.USE_COSINE = False
+_C.SOLVER.MIN_LR = 0.000001
+
+_C.SOLVER.GAMMA = 0.1
+_C.SOLVER.STEPS = (30000,)
+
+_C.SOLVER.USE_AUTOSTEP = False
+_C.SOLVER.STEP_PATIENCE = 5
+
+_C.SOLVER.WARMUP_FACTOR = 1.0 / 3
+_C.SOLVER.WARMUP_ITERS = 500
+_C.SOLVER.WARMUP_METHOD = "linear"
+
+_C.SOLVER.CHECKPOINT_PERIOD = 2500
+_C.SOLVER.CHECKPOINT_PER_EPOCH = -1.0
+_C.SOLVER.TEST_WITH_INFERENCE = False
+_C.SOLVER.AUTO_TERMINATE_PATIENCE = -1
+# Number of images per batch
+# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
+# see 2 images per batch
+_C.SOLVER.IMS_PER_BATCH = 16
+# This is the max negative ratio allowed per batch
+_C.SOLVER.MAX_NEG_PER_BATCH = 0.1
+
+_C.SOLVER.SEED = 0
+_C.SOLVER.DISABLE_OUTPUT_DISTRIBUTED = False
+
+
+_C.SOLVER.PROMPT_PROBING_LEVEL = -1.0 
+# -1 means tuning the whole model; 
+# 1 means tuning the whole language model; 1.5 means tuning the box head as well
+
+_C.SOLVER.FIND_UNUSED_PARAMETERS = True
+_C.SOLVER.DATASET_LENGTH = -1 # Just for logging purpose
+_C.SOLVER.TUNING_HIGHLEVEL_OVERRIDE = None
+_C.SOLVER.USE_EMA_FOR_MONITOR = False
+
+_C.SOLVER.WEIGHT_DECAY_SCHEDULE = False
+_C.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO = 0.667
+
+# ---------------------------------------------------------------------------- #
+# Specific test options
+# ---------------------------------------------------------------------------- #
+_C.TEST = CN()
+_C.TEST.EXPECTED_RESULTS = []
+_C.TEST.EXPECTED_RESULTS_SIGMA_TOL = 4
+_C.TEST.DURING_TRAINING = False
+# Number of images per batch
+# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
+# see 2 images per batch
+_C.TEST.IMS_PER_BATCH = 16
+# Special Test Configuration
+_C.TEST.USE_MULTISCALE = False
+# _C.TEST.SCALES = (400, 600, 800, 1000, 1200, 1400)
+# _C.TEST.RANGES = ((96, 10000), (64, 10000), (0, 10000), (0, 10000), (0, 256), (0, 192))
+_C.TEST.SCALES = (400, 500, 600, 640, 700, 900, 1000, 1100, 1200, 1300, 1400, 1800)
+_C.TEST.RANGES = ((96, 10000), (96, 10000), (64, 10000), (64, 10000), (64, 10000), (0, 10000), (0, 10000), (0, 256), (0, 256), (0, 192), (0, 192), (0, 96))
+_C.TEST.MAX_SIZE = 2500
+_C.TEST.FLIP = True
+_C.TEST.SPECIAL_NMS = 'none'  # ('none', 'soft-nms', 'vote', 'soft-vote')
+_C.TEST.TH = 0.6  # threshold for nms or vote
+_C.TEST.PRE_NMS_TOP_N = 1000
+_C.TEST.NUM_CLASSES = 81
+_C.TEST.SELECT_CLASSES = ()
+
+_C.TEST.EVAL_TASK = ""
+_C.TEST.SUBSET = -1
+_C.TEST.CHUNKED_EVALUATION = -1
+_C.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM = -1
+# ---------------------------------------------------------------------------- #
+# Misc options
+# ---------------------------------------------------------------------------- #
+_C.OUTPUT_DIR = "OUTPUT"
+
+_C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")
+
+# TensorBoard experiment location
+_C.TENSORBOARD_EXP = "OUTPUT"
+
+
+_C.GLIPKNOW = CN()
+_C.GLIPKNOW.KNOWLEDGE_FILE = ""
+_C.GLIPKNOW.KNOWLEDGE_TYPE = ""
+_C.GLIPKNOW.MAX_NUM_CLASSES_PER_BATCH_TRAIN = -1
+_C.GLIPKNOW.PARALLEL_LANGUAGE_INPUT = False
+_C.GLIPKNOW.LAN_FEATURE_AGG_TYPE = "first"
+_C.GLIPKNOW.GPT3_NUM = 5
+_C.GLIPKNOW.WIKI_AND_GPT3 = False
\ No newline at end of file
diff --git a/maskrcnn_benchmark/config/paths_catalog.py b/maskrcnn_benchmark/config/paths_catalog.py
new file mode 100644
index 0000000000000000000000000000000000000000..be63e5715434d696cb1480c8a5b436b642808afb
--- /dev/null
+++ b/maskrcnn_benchmark/config/paths_catalog.py
@@ -0,0 +1,447 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""Centralized catalog of paths."""
+
+import os
+
+
+def try_to_find(file, return_dir=False, search_path=['./DATASET', './OUTPUT', './data', './MODEL']):
+    if not file:
+        return file
+
+    if file.startswith('catalog://'):
+        return file
+
+    DATASET_PATH = ['./']
+    if 'DATASET' in os.environ:
+        DATASET_PATH.append(os.environ['DATASET'])
+    DATASET_PATH += search_path
+
+    for path in DATASET_PATH:
+        if os.path.exists(os.path.join(path, file)):
+            if return_dir:
+                return path
+            else:
+                return os.path.join(path, file)
+
+    print('Cannot find {} in {}'.format(file, DATASET_PATH))
+    exit(1)
+
+
+class DatasetCatalog(object):
+    DATASETS = {
+        # pretrained grounding dataset
+        # mixed vg and coco
+        "mixed_train": {
+            "coco_img_dir": "coco/train2014",
+            "vg_img_dir": "gqa/images",
+            "ann_file": "mdetr_annotations/final_mixed_train.json",
+        },
+        "mixed_train_no_coco": {
+            "coco_img_dir": "coco/train2014",
+            "vg_img_dir": "gqa/images",
+            "ann_file": "mdetr_annotations/final_mixed_train_no_coco.json",
+        },
+
+        # flickr30k
+        "flickr30k_train": {
+            "img_folder": "flickr30k/flickr30k_images/train",
+            "ann_file": "mdetr_annotations/final_flickr_separateGT_train.json",
+            "is_train": True
+        },
+        "flickr30k_val": {
+            "img_folder": "flickr30k/flickr30k_images/val",
+            "ann_file": "mdetr_annotations/final_flickr_separateGT_val.json",
+            "is_train": False
+        },
+        "flickr30k_test": {
+            "img_folder": "flickr30k/flickr30k_images/test",
+            "ann_file": "mdetr_annotations/final_flickr_separateGT_test.json",
+            "is_train": False
+        },
+
+        # refcoco
+        "refexp_all_val": {
+            "img_dir": "refcoco/train2014",
+            "ann_file": "mdetr_annotations/final_refexp_val.json",
+            "is_train": False
+        },
+
+        # gqa
+        "gqa_val": {
+            "img_dir": "gqa/images",
+            "ann_file": "mdetr_annotations/final_gqa_val.json",
+            "is_train": False
+        },
+
+        # phrasecut
+        "phrasecut_train": {
+            "img_dir": "gqa/images",
+            "ann_file": "mdetr_annotations/finetune_phrasecut_train.json",
+            "is_train": True
+        },
+
+
+        # od to grounding
+        # coco tsv
+        "coco_dt_train": {
+            "dataset_file": "coco_dt",
+            "yaml_path": "coco_tsv/coco_obj.yaml",
+            "is_train": True,
+        },
+        "COCO_odinw_train_8copy_dt_train": {
+            "dataset_file": "coco_odinw_dt",
+            "yaml_path": "coco_tsv/COCO_odinw_train_8copy.yaml",
+            "is_train": True,
+        },
+        "COCO_odinw_val_dt_train": {
+            "dataset_file": "coco_odinw_dt",
+            "yaml_path": "coco_tsv/COCO_odinw_val.yaml",
+            "is_train": False,
+        },
+        # lvis tsv
+        "lvisv1_dt_train": {
+            "dataset_file": "lvisv1_dt",
+            "yaml_path": "coco_tsv/LVIS_v1_train.yaml",
+            "is_train": True,
+        },
+        "LVIS_odinw_train_8copy_dt_train": {
+            "dataset_file": "coco_odinw_dt",
+            "yaml_path": "coco_tsv/LVIS_odinw_train_8copy.yaml",
+            "is_train": True,
+        },
+        # object365 tsv
+        "object365_dt_train": {
+            "dataset_file": "object365_dt",
+            "yaml_path": "Objects365/objects365_train_vgoiv6.cas2000.yaml",
+            "is_train": True,
+        },
+        "object365_odinw_2copy_dt_train": {
+            "dataset_file": "object365_odinw_dt",
+            "yaml_path": "Objects365/objects365_train_odinw.cas2000_2copy.yaml",
+            "is_train": True,
+        },
+        "objects365_odtsv_train": {
+            "dataset_file": "objects365_odtsv",
+            "yaml_path": "Objects365/train.cas2000.yaml",
+            "is_train": True,
+        },
+        "objects365_odtsv_val": {
+            "dataset_file": "objects365_odtsv",
+            "yaml_path": "Objects365/val.yaml",
+            "is_train": False,
+        },
+
+        # ImagetNet OD
+        "imagenetod_train_odinw_2copy_dt": {
+            "dataset_file": "imagenetod_odinw_dt",
+            "yaml_path": "imagenet_od/imagenetod_train_odinw_2copy.yaml",
+            "is_train": True,
+        },
+
+        # OpenImage OD
+        "oi_train_odinw_dt": {
+            "dataset_file": "oi_odinw_dt",
+            "yaml_path": "openimages_v5c/oi_train_odinw.cas.2000.yaml",
+            "is_train": True,
+        },
+
+        # vg tsv
+        "vg_dt_train": {
+            "dataset_file": "vg_dt",
+            "yaml_path": "visualgenome/train_vgoi6_clipped.yaml",
+            "is_train": True,
+        },
+
+        "vg_odinw_clipped_8copy_dt_train": {
+            "dataset_file": "vg_odinw_clipped_8copy_dt",
+            "yaml_path": "visualgenome/train_odinw_clipped_8copy.yaml",
+            "is_train": True,
+        },
+        "vg_vgoi6_clipped_8copy_dt_train": {
+            "dataset_file": "vg_vgoi6_clipped_8copy_dt",
+            "yaml_path": "visualgenome/train_vgoi6_clipped_8copy.yaml",
+            "is_train": True,
+        },
+
+        # coco json
+        "coco_grounding_train": {
+            "img_dir": "coco/train2017",
+            "ann_file": "coco/annotations/instances_train2017.json",
+            "is_train": True,
+        },
+
+        "lvis_grounding_train": {
+            "img_dir": "coco",
+            "ann_file": "coco/annotations/lvis_od_train.json"
+        },
+
+
+        "lvis_val": {
+            "img_dir": "coco",
+            "ann_file": "coco/annotations/lvis_od_val.json"
+        },
+        "coco_2017_train": {
+            "img_dir": "coco/train2017",
+            "ann_file": "coco/annotations/instances_train2017.json"
+        },
+        "coco_2017_val": {
+            "img_dir": "coco/val2017",
+            "ann_file": "coco/annotations/instances_val2017.json"
+        },
+        "coco_2017_test": {
+            "img_dir": "coco/test2017",
+            "ann_file": "coco/annotations/image_info_test-dev2017.json"
+        },
+        "coco_2014_train": {
+            "img_dir": "coco/train2014",
+            "ann_file": "coco/annotations/instances_train2014.json"
+        },
+        "coco_2014_val": {
+            "img_dir": "coco/val2014",
+            "ann_file": "coco/annotations/instances_val2014.json"
+        },
+        "coco_2014_minival": {
+            "img_dir": "coco/val2014",
+            "ann_file": "coco/annotations/instances_minival2014.json"
+        },
+    }
+
+    @staticmethod
+    def set(name, info):
+        DatasetCatalog.DATASETS.update({name: info})
+
+    @staticmethod
+    def get(name):
+
+        if name.endswith('_bg'):
+            attrs = DatasetCatalog.DATASETS[name]
+            data_dir = try_to_find(attrs["ann_file"], return_dir=True)
+            args = dict(
+                root=os.path.join(data_dir, attrs["img_dir"]),
+                ann_file=os.path.join(data_dir, attrs["ann_file"]),
+            )
+            return dict(
+                factory="Background",
+                args=args,
+            )
+        else:
+            if "bing" in name.split("_"):
+                attrs = DatasetCatalog.DATASETS["bing_caption_train"]
+            else:
+                attrs = DatasetCatalog.DATASETS[name]
+
+            if "voc" in name and 'split' in attrs:
+                data_dir = try_to_find(attrs["data_dir"], return_dir=True)
+                args = dict(
+                    data_dir=os.path.join(data_dir, attrs["data_dir"]),
+                    split=attrs["split"],
+                )
+                return dict(
+                    factory="PascalVOCDataset",
+                    args=args,
+                )
+            elif "mixed" in name:
+                vg_img_dir = try_to_find(attrs["vg_img_dir"], return_dir=True)
+                coco_img_dir = try_to_find(attrs["coco_img_dir"], return_dir=True)
+                ann_file = try_to_find(attrs["ann_file"], return_dir=True)
+                args = dict(
+                    img_folder_coco=os.path.join(coco_img_dir, attrs["coco_img_dir"]),
+                    img_folder_vg=os.path.join(vg_img_dir, attrs["vg_img_dir"]),
+                    ann_file=os.path.join(ann_file, attrs["ann_file"])
+                )
+                return dict(
+                    factory="MixedDataset",
+                    args=args,
+                )
+            elif "flickr" in name:
+                img_dir = try_to_find(attrs["img_folder"], return_dir=True)
+                ann_dir = try_to_find(attrs["ann_file"], return_dir=True)
+                args = dict(
+                    img_folder=os.path.join(img_dir, attrs["img_folder"]),
+                    ann_file=os.path.join(ann_dir, attrs["ann_file"]),
+                    is_train=attrs["is_train"]
+                )
+                return dict(
+                    factory="FlickrDataset",
+                    args=args,
+                )
+            elif "refexp" in name:
+                img_dir = try_to_find(attrs["img_dir"], return_dir=True)
+                ann_dir = try_to_find(attrs["ann_file"], return_dir=True)
+                args = dict(
+                    img_folder=os.path.join(img_dir, attrs["img_dir"]),
+                    ann_file=os.path.join(ann_dir, attrs["ann_file"]),
+                )
+                return dict(
+                    factory="RefExpDataset",
+                    args=args,
+                )
+            elif "gqa" in name:
+                img_dir = try_to_find(attrs["img_dir"], return_dir=True)
+                ann_dir = try_to_find(attrs["ann_file"], return_dir=True)
+                args = dict(
+                    img_folder=os.path.join(img_dir, attrs["img_dir"]),
+                    ann_file=os.path.join(ann_dir, attrs["ann_file"]),
+                )
+                return dict(
+                    factory="GQADataset",
+                    args=args,
+                )
+            elif "phrasecut" in name:
+                img_dir = try_to_find(attrs["img_dir"], return_dir=True)
+                ann_dir = try_to_find(attrs["ann_file"], return_dir=True)
+                args = dict(
+                    img_folder=os.path.join(img_dir, attrs["img_dir"]),
+                    ann_file=os.path.join(ann_dir, attrs["ann_file"]),
+                )
+                return dict(
+                    factory="PhrasecutDetection",
+                    args=args,
+                )
+            elif "_caption" in name:
+                yaml_path = try_to_find(attrs["yaml_path"], return_dir=True)
+                if "no_coco" in name:
+                    yaml_name = attrs["yaml_name_no_coco"]
+                else:
+                    yaml_name = attrs["yaml_name"]
+                yaml_file_name = "{}.{}.yaml".format(yaml_name, name.split("_")[2])
+                args = dict(
+                    yaml_file=os.path.join(yaml_path, attrs["yaml_path"], yaml_file_name)
+                )
+                return dict(
+                    factory="CaptionTSV",
+                    args=args,
+                )
+            elif "inferencecap" in name:
+                yaml_file_name = try_to_find(attrs["yaml_path"])
+                args = dict(
+                    yaml_file=yaml_file_name)
+                return dict(
+                    factory="CaptionTSV",
+                    args=args,
+                )
+            elif "pseudo_data" in name:
+                args = dict(
+                    yaml_file=try_to_find(attrs["yaml_path"])
+                )
+                return dict(
+                    factory="PseudoData",
+                    args=args,
+                )
+            elif "_dt" in name:
+                dataset_file = attrs["dataset_file"]
+                yaml_path = try_to_find(attrs["yaml_path"], return_dir=True)
+                args = dict(
+                    name=dataset_file,
+                    yaml_file=os.path.join(yaml_path, attrs["yaml_path"]),
+                )
+                return dict(
+                    factory="CocoDetectionTSV",
+                    args=args,
+                )
+            elif "_odtsv" in name:
+                dataset_file = attrs["dataset_file"]
+                yaml_path = try_to_find(attrs["yaml_path"], return_dir=True)
+                args = dict(
+                    name=dataset_file,
+                    yaml_file=os.path.join(yaml_path, attrs["yaml_path"]),
+                )
+                return dict(
+                    factory="ODTSVDataset",
+                    args=args,
+                )
+            elif "_grounding" in name:
+                img_dir = try_to_find(attrs["img_dir"], return_dir=True)
+                ann_dir = try_to_find(attrs["ann_file"], return_dir=True)
+                args = dict(
+                    img_folder=os.path.join(img_dir, attrs["img_dir"]),
+                    ann_file=os.path.join(ann_dir, attrs["ann_file"]),
+                )
+                return dict(
+                    factory="CocoGrounding",
+                    args=args,
+                )
+            elif "lvis_evaluation" in name:
+                img_dir = try_to_find(attrs["img_dir"], return_dir=True)
+                ann_dir = try_to_find(attrs["ann_file"], return_dir=True)
+                args = dict(
+                    img_folder=os.path.join(img_dir, attrs["img_dir"]),
+                    ann_file=os.path.join(ann_dir, attrs["ann_file"]),
+                )
+                return dict(
+                    factory="LvisDetection",
+                    args=args,
+                )
+            else:
+                ann_dir = try_to_find(attrs["ann_file"], return_dir=True)
+                img_dir = try_to_find(attrs["img_dir"], return_dir=True)
+                args = dict(
+                    root=os.path.join(img_dir, attrs["img_dir"]),
+                    ann_file=os.path.join(ann_dir, attrs["ann_file"]),
+                )
+                for k, v in attrs.items():
+                    args.update({k: os.path.join(ann_dir, v)})
+                return dict(
+                    factory="COCODataset",
+                    args=args,
+                )
+
+        raise RuntimeError("Dataset not available: {}".format(name))
+
+
+class ModelCatalog(object):
+    S3_C2_DETECTRON_URL = "https://dl.fbaipublicfiles.com/detectron"
+    C2_IMAGENET_MODELS = {
+        "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
+        "MSRA/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
+        "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
+        "MSRA/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
+        "FAIR/20171220/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
+        "FAIR/20171220/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
+    }
+
+    C2_DETECTRON_SUFFIX = "output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl"
+    C2_DETECTRON_MODELS = {
+        "35857197/e2e_faster_rcnn_R-50-C4_1x": "01_33_49.iAX0mXvW",
+        "35857345/e2e_faster_rcnn_R-50-FPN_1x": "01_36_30.cUF7QR7I",
+        "35857890/e2e_faster_rcnn_R-101-FPN_1x": "01_38_50.sNxI7sX7",
+        "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "06_31_39.5MIHi1fZ",
+        "35858791/e2e_mask_rcnn_R-50-C4_1x": "01_45_57.ZgkA7hPB",
+        "35858933/e2e_mask_rcnn_R-50-FPN_1x": "01_48_14.DzEQe4wC",
+        "35861795/e2e_mask_rcnn_R-101-FPN_1x": "02_31_37.KqyEK4tT",
+        "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "06_35_59.RZotkLKI",
+    }
+
+    @staticmethod
+    def get(name):
+        if name.startswith("Caffe2Detectron/COCO"):
+            return ModelCatalog.get_c2_detectron_12_2017_baselines(name)
+        if name.startswith("ImageNetPretrained"):
+            return ModelCatalog.get_c2_imagenet_pretrained(name)
+        raise RuntimeError("model not present in the catalog {}".format(name))
+
+    @staticmethod
+    def get_c2_imagenet_pretrained(name):
+        prefix = ModelCatalog.S3_C2_DETECTRON_URL
+        name = name[len("ImageNetPretrained/"):]
+        name = ModelCatalog.C2_IMAGENET_MODELS[name]
+        url = "/".join([prefix, name])
+        return url
+
+    @staticmethod
+    def get_c2_detectron_12_2017_baselines(name):
+        # Detectron C2 models are stored following the structure
+        # prefix/<model_id>/2012_2017_baselines/<model_name>.yaml.<signature>/suffix
+        # we use as identifiers in the catalog Caffe2Detectron/COCO/<model_id>/<model_name>
+        prefix = ModelCatalog.S3_C2_DETECTRON_URL
+        suffix = ModelCatalog.C2_DETECTRON_SUFFIX
+        # remove identification prefix
+        name = name[len("Caffe2Detectron/COCO/"):]
+        # split in <model_id> and <model_name>
+        model_id, model_name = name.split("/")
+        # parsing to make it match the url address from the Caffe2 models
+        model_name = "{}.yaml".format(model_name)
+        signature = ModelCatalog.C2_DETECTRON_MODELS[name]
+        unique_name = ".".join([model_name, signature])
+        url = "/".join([prefix, model_id, "12_2017_baselines", unique_name, suffix])
+        return url
diff --git a/maskrcnn_benchmark/csrc/ROIAlign.h b/maskrcnn_benchmark/csrc/ROIAlign.h
new file mode 100644
index 0000000000000000000000000000000000000000..2683dbf52e120eebb7b60bb2257cd3527c5a86c3
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/ROIAlign.h
@@ -0,0 +1,46 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+// Interface for Python
+at::Tensor ROIAlign_forward(const at::Tensor& input,
+                            const at::Tensor& rois,
+                            const float spatial_scale,
+                            const int pooled_height,
+                            const int pooled_width,
+                            const int sampling_ratio) {
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+}
+
+at::Tensor ROIAlign_backward(const at::Tensor& grad,
+                             const at::Tensor& rois,
+                             const float spatial_scale,
+                             const int pooled_height,
+                             const int pooled_width,
+                             const int batch_size,
+                             const int channels,
+                             const int height,
+                             const int width,
+                             const int sampling_ratio) {
+  if (grad.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/maskrcnn_benchmark/csrc/ROIPool.h b/maskrcnn_benchmark/csrc/ROIPool.h
new file mode 100644
index 0000000000000000000000000000000000000000..9b62b2dcb8f69ac65bc1fdf0eeb5fa556539bc13
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/ROIPool.h
@@ -0,0 +1,48 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor& input,
+                                const at::Tensor& rois,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width) {
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+at::Tensor ROIPool_backward(const at::Tensor& grad,
+                                 const at::Tensor& input,
+                                 const at::Tensor& rois,
+                                 const at::Tensor& argmax,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int batch_size,
+                                 const int channels,
+                                 const int height,
+                                 const int width) {
+  if (grad.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+
+
diff --git a/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h b/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h
new file mode 100644
index 0000000000000000000000000000000000000000..e220c12ae558a176f6b4b0a6640e724358f2ecb0
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+// Interface for Python
+at::Tensor SigmoidFocalLoss_forward(
+		const at::Tensor& logits,
+                const at::Tensor& targets,
+		const int num_classes, 
+		const float gamma, 
+		const float alpha) {
+  if (logits.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+at::Tensor SigmoidFocalLoss_backward(
+			     const at::Tensor& logits,
+                             const at::Tensor& targets,
+			     const at::Tensor& d_losses,
+			     const int num_classes,
+			     const float gamma,
+			     const float alpha) {
+  if (logits.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
diff --git a/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0c061351588df7752293ed84bba1c900768e3ab8
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp
@@ -0,0 +1,257 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include "cpu/vision.h"
+
+// implementation taken from Caffe2
+template <typename T>
+struct PreCalc {
+  int pos1;
+  int pos2;
+  int pos3;
+  int pos4;
+  T w1;
+  T w2;
+  T w3;
+  T w4;
+};
+
+template <typename T>
+void pre_calc_for_bilinear_interpolate(
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int iy_upper,
+    const int ix_upper,
+    T roi_start_h,
+    T roi_start_w,
+    T bin_size_h,
+    T bin_size_w,
+    int roi_bin_grid_h,
+    int roi_bin_grid_w,
+    std::vector<PreCalc<T>>& pre_calc) {
+  int pre_calc_index = 0;
+  for (int ph = 0; ph < pooled_height; ph++) {
+    for (int pw = 0; pw < pooled_width; pw++) {
+      for (int iy = 0; iy < iy_upper; iy++) {
+        const T yy = roi_start_h + ph * bin_size_h +
+            static_cast<T>(iy + .5f) * bin_size_h /
+                static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+        for (int ix = 0; ix < ix_upper; ix++) {
+          const T xx = roi_start_w + pw * bin_size_w +
+              static_cast<T>(ix + .5f) * bin_size_w /
+                  static_cast<T>(roi_bin_grid_w);
+
+          T x = xx;
+          T y = yy;
+          // deal with: inverse elements are out of feature map boundary
+          if (y < -1.0 || y > height || x < -1.0 || x > width) {
+            // empty
+            PreCalc<T> pc;
+            pc.pos1 = 0;
+            pc.pos2 = 0;
+            pc.pos3 = 0;
+            pc.pos4 = 0;
+            pc.w1 = 0;
+            pc.w2 = 0;
+            pc.w3 = 0;
+            pc.w4 = 0;
+            pre_calc[pre_calc_index] = pc;
+            pre_calc_index += 1;
+            continue;
+          }
+
+          if (y <= 0) {
+            y = 0;
+          }
+          if (x <= 0) {
+            x = 0;
+          }
+
+          int y_low = (int)y;
+          int x_low = (int)x;
+          int y_high;
+          int x_high;
+
+          if (y_low >= height - 1) {
+            y_high = y_low = height - 1;
+            y = (T)y_low;
+          } else {
+            y_high = y_low + 1;
+          }
+
+          if (x_low >= width - 1) {
+            x_high = x_low = width - 1;
+            x = (T)x_low;
+          } else {
+            x_high = x_low + 1;
+          }
+
+          T ly = y - y_low;
+          T lx = x - x_low;
+          T hy = 1. - ly, hx = 1. - lx;
+          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+          // save weights and indeces
+          PreCalc<T> pc;
+          pc.pos1 = y_low * width + x_low;
+          pc.pos2 = y_low * width + x_high;
+          pc.pos3 = y_high * width + x_low;
+          pc.pos4 = y_high * width + x_high;
+          pc.w1 = w1;
+          pc.w2 = w2;
+          pc.w3 = w3;
+          pc.w4 = w4;
+          pre_calc[pre_calc_index] = pc;
+
+          pre_calc_index += 1;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void ROIAlignForward_cpu_kernel(
+    const int nthreads,
+    const T* bottom_data,
+    const T& spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    const T* bottom_rois,
+    //int roi_cols,
+    T* top_data) {
+  //AT_ASSERT(roi_cols == 4 || roi_cols == 5);
+  int roi_cols = 5;
+
+  int n_rois = nthreads / channels / pooled_width / pooled_height;
+  // (n, c, ph, pw) is an element in the pooled output
+  // can be parallelized using omp
+  // #pragma omp parallel for num_threads(32)
+  for (int n = 0; n < n_rois; n++) {
+    int index_n = n * channels * pooled_width * pooled_height;
+
+    // roi could have 4 or 5 columns
+    const T* offset_bottom_rois = bottom_rois + n * roi_cols;
+    int roi_batch_ind = 0;
+    if (roi_cols == 5) {
+      roi_batch_ind = offset_bottom_rois[0];
+      offset_bottom_rois++;
+    }
+
+    // Do not using rounding; this implementation detail is critical
+    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
+    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
+    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
+    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
+    // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
+    // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
+    // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
+    // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+    // we want to precalculate indeces and weights shared by all chanels,
+    // this is the key point of optimiation
+    std::vector<PreCalc<T>> pre_calc(
+        roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
+    pre_calc_for_bilinear_interpolate(
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        roi_bin_grid_h,
+        roi_bin_grid_w,
+        roi_start_h,
+        roi_start_w,
+        bin_size_h,
+        bin_size_w,
+        roi_bin_grid_h,
+        roi_bin_grid_w,
+        pre_calc);
+
+      for (int c = 0; c < channels; c++) {
+      int index_n_c = index_n + c * pooled_width * pooled_height;
+      const T* offset_bottom_data =
+          bottom_data + (roi_batch_ind * channels + c) * height * width;
+      int pre_calc_index = 0;
+
+      for (int ph = 0; ph < pooled_height; ph++) {
+        for (int pw = 0; pw < pooled_width; pw++) {
+          int index = index_n_c + ph * pooled_width + pw;
+
+          T output_val = 0.;
+          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+              PreCalc<T> pc = pre_calc[pre_calc_index];
+              output_val += pc.w1 * offset_bottom_data[pc.pos1] +
+                  pc.w2 * offset_bottom_data[pc.pos2] +
+                  pc.w3 * offset_bottom_data[pc.pos3] +
+                  pc.w4 * offset_bottom_data[pc.pos4];
+
+              pre_calc_index += 1;
+            }
+          }
+          output_val /= count;
+
+          top_data[index] = output_val;
+        } // for pw
+      } // for ph
+    } // for c
+  } // for n
+}
+
+at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
+                                const at::Tensor& rois,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width,
+                                const int sampling_ratio) {
+  AT_ASSERTM(!input.device().is_cuda(), "input must be a CPU tensor");
+  AT_ASSERTM(!rois.device().is_cuda(), "rois must be a CPU tensor");
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+
+  if (output.numel() == 0) {
+    return output;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] {
+    ROIAlignForward_cpu_kernel<scalar_t>(
+         output_size,
+         input.data_ptr<scalar_t>(),
+         spatial_scale,
+         channels,
+         height,
+         width,
+         pooled_height,
+         pooled_width,
+         sampling_ratio,
+         rois.data_ptr<scalar_t>(),
+         output.data_ptr<scalar_t>());
+  });
+  return output;
+}
diff --git a/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..11b7aa60fdca907352b334f142faadb46d662f99
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp
@@ -0,0 +1,75 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include "cpu/vision.h"
+
+
+template <typename scalar_t>
+at::Tensor nms_cpu_kernel(const at::Tensor& dets,
+                          const at::Tensor& scores,
+                          const float threshold) {
+  AT_ASSERTM(!dets.device().is_cuda(), "dets must be a CPU tensor");
+  AT_ASSERTM(!scores.device().is_cuda(), "scores must be a CPU tensor");
+  AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
+
+  if (dets.numel() == 0) {
+    return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
+  }
+
+  auto x1_t = dets.select(1, 0).contiguous();
+  auto y1_t = dets.select(1, 1).contiguous();
+  auto x2_t = dets.select(1, 2).contiguous();
+  auto y2_t = dets.select(1, 3).contiguous();
+
+  at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
+
+  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
+
+  auto ndets = dets.size(0);
+  at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
+
+  auto suppressed = suppressed_t.data_ptr<uint8_t>();
+  auto order = order_t.data_ptr<int64_t>();
+  auto x1 = x1_t.data_ptr<scalar_t>();
+  auto y1 = y1_t.data_ptr<scalar_t>();
+  auto x2 = x2_t.data_ptr<scalar_t>();
+  auto y2 = y2_t.data_ptr<scalar_t>();
+  auto areas = areas_t.data_ptr<scalar_t>();
+
+  for (int64_t _i = 0; _i < ndets; _i++) {
+    auto i = order[_i];
+    if (suppressed[i] == 1)
+      continue;
+    auto ix1 = x1[i];
+    auto iy1 = y1[i];
+    auto ix2 = x2[i];
+    auto iy2 = y2[i];
+    auto iarea = areas[i];
+
+    for (int64_t _j = _i + 1; _j < ndets; _j++) {
+      auto j = order[_j];
+      if (suppressed[j] == 1)
+        continue;
+      auto xx1 = std::max(ix1, x1[j]);
+      auto yy1 = std::max(iy1, y1[j]);
+      auto xx2 = std::min(ix2, x2[j]);
+      auto yy2 = std::min(iy2, y2[j]);
+
+      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
+      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
+      auto inter = w * h;
+      auto ovr = inter / (iarea + areas[j] - inter);
+      if (ovr >= threshold)
+        suppressed[j] = 1;
+   }
+  }
+  return at::nonzero(suppressed_t == 0).squeeze(1);
+}
+
+at::Tensor nms_cpu(const at::Tensor& dets,
+               const at::Tensor& scores,
+               const float threshold) {
+  at::Tensor result;
+  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
+    result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
+  });
+  return result;
+}
diff --git a/maskrcnn_benchmark/csrc/cpu/soft_nms.cpp b/maskrcnn_benchmark/csrc/cpu/soft_nms.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..423941d71e29f5b9823006d57cdf0088646586ed
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cpu/soft_nms.cpp
@@ -0,0 +1,117 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include "cpu/vision.h"
+
+
+template <typename scalar_t>
+std::pair<at::Tensor, at::Tensor> soft_nms_cpu_kernel(const at::Tensor& dets,
+                                                      const at::Tensor& scores,
+                                                      const float threshold,
+                                                      const float sigma) {
+  AT_ASSERTM(!dets.device().is_cuda(), "dets must be a CPU tensor");
+  AT_ASSERTM(!scores.device().is_cuda(), "scores must be a CPU tensor");
+  AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
+
+  if (dets.numel() == 0) {
+    return std::make_pair(at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU)),
+                          at::empty({0}, scores.options().dtype(at::kFloat).device(at::kCPU)));
+  }
+
+  auto x1_t = dets.select(1, 0).contiguous();
+  auto y1_t = dets.select(1, 1).contiguous();
+  auto x2_t = dets.select(1, 2).contiguous();
+  auto y2_t = dets.select(1, 3).contiguous();
+
+  auto scores_t = scores.clone();
+
+  at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
+  auto ndets = dets.size(0);
+  auto inds_t = at::arange(ndets, dets.options().dtype(at::kLong).device(at::kCPU));
+
+  auto x1 = x1_t.data_ptr<scalar_t>();
+  auto y1 = y1_t.data_ptr<scalar_t>();
+  auto x2 = x2_t.data_ptr<scalar_t>();
+  auto y2 = y2_t.data_ptr<scalar_t>();
+  auto s = scores_t.data_ptr<scalar_t>();
+  auto inds = inds_t.data_ptr<int64_t>();
+  auto areas = areas_t.data_ptr<scalar_t>();
+
+  for (int64_t i = 0; i < ndets; i++) {
+
+    auto ix1 = x1[i];
+    auto iy1 = y1[i];
+    auto ix2 = x2[i];
+    auto iy2 = y2[i];
+    auto is = s[i];
+    auto ii = inds[i];
+    auto iarea = areas[i];
+
+    auto maxpos = scores_t.slice(0, i, ndets).argmax().item<int64_t>() + i;
+
+    // add max box as a detection
+    x1[i] = x1[maxpos];
+    y1[i] = y1[maxpos];
+    x2[i] = x2[maxpos];
+    y2[i] = y2[maxpos];
+    s[i] = s[maxpos];
+    inds[i] = inds[maxpos];
+    areas[i] = areas[maxpos];
+
+    // swap ith box with position of max box
+    x1[maxpos] = ix1;
+    y1[maxpos] = iy1;
+    x2[maxpos] = ix2;
+    y2[maxpos] = iy2;
+    s[maxpos] = is;
+    inds[maxpos] = ii;
+    areas[maxpos] = iarea;
+
+    ix1 = x1[i];
+    iy1 = y1[i];
+    ix2 = x2[i];
+    iy2 = y2[i];
+    iarea = areas[i];
+
+    // NMS iterations, note that ndets changes if detection boxes
+    // fall below threshold
+    for (int64_t j = i + 1; j < ndets; j++) {
+      auto xx1 = std::max(ix1, x1[j]);
+      auto yy1 = std::max(iy1, y1[j]);
+      auto xx2 = std::min(ix2, x2[j]);
+      auto yy2 = std::min(iy2, y2[j]);
+
+      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
+      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
+
+      auto inter = w * h;
+      auto ovr = inter / (iarea + areas[j] - inter);
+
+      s[j] = s[j] * std::exp(- std::pow(ovr, 2.0) / sigma);
+
+      // if box score falls below threshold, discard the box by
+      // swapping with last box update ndets
+      if (s[j] < threshold) {
+        x1[j] = x1[ndets - 1];
+        y1[j] = y1[ndets - 1];
+        x2[j] = x2[ndets - 1];
+        y2[j] = y2[ndets - 1];
+        s[j] = s[ndets - 1];
+        inds[j] = inds[ndets - 1];
+        areas[j] = areas[ndets - 1];
+        j--;
+        ndets--;
+      }
+    }
+  }
+  return std::make_pair(inds_t.slice(0, 0, ndets), scores_t.slice(0, 0, ndets));
+}
+
+std::pair<at::Tensor, at::Tensor> soft_nms_cpu(const at::Tensor& dets,
+                                               const at::Tensor& scores,
+                                               const float threshold,
+                                               const float sigma) {
+  std::pair<at::Tensor, at::Tensor> result;
+  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "soft_nms", [&] {
+    result = soft_nms_cpu_kernel<scalar_t>(dets, scores, threshold, sigma);
+  });
+  return result;
+}
\ No newline at end of file
diff --git a/maskrcnn_benchmark/csrc/cpu/vision.h b/maskrcnn_benchmark/csrc/cpu/vision.h
new file mode 100644
index 0000000000000000000000000000000000000000..e00ef683150eb9d46d0e4f6a30f55a7230a52e93
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cpu/vision.h
@@ -0,0 +1,22 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+#include <torch/extension.h>
+
+
+at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
+                                const at::Tensor& rois,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width,
+                                const int sampling_ratio);
+
+
+at::Tensor nms_cpu(const at::Tensor& dets,
+                   const at::Tensor& scores,
+                   const float threshold);
+
+
+std::pair<at::Tensor, at::Tensor> soft_nms_cpu(const at::Tensor& dets,
+                                               const at::Tensor& scores,
+                                               const float threshold,
+                                               const float sigma);
\ No newline at end of file
diff --git a/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9ed1a0adfd841a17d3574dee6ac703820fcfe144
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
@@ -0,0 +1,346 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+// TODO make it in a common file
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+
+template <typename T>
+__device__ T bilinear_interpolate(const T* bottom_data,
+    const int height, const int width,
+    T y, T x,
+    const int index /* index for debug only*/) {
+
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    //empty
+    return 0;
+  }
+
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;
+
+  int y_low = (int) y;
+  int x_low = (int) x;
+  int y_high;
+  int x_high;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T) y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T) x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+  // do bilinear interpolation
+  T v1 = bottom_data[y_low * width + x_low];
+  T v2 = bottom_data[y_low * width + x_high];
+  T v3 = bottom_data[y_high * width + x_low];
+  T v4 = bottom_data[y_high * width + x_high];
+  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  return val;
+}
+
+template <typename T>
+__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
+    const T spatial_scale, const int channels,
+    const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const int sampling_ratio,
+    const T* bottom_rois, T* top_data) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not using rounding; this implementation detail is critical
+    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
+    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
+    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
+    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
+    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+    T output_val = 0.;
+    for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
+    {
+      const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix ++)
+      {
+        const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
+        output_val += val;
+      }
+    }
+    output_val /= count;
+
+    top_data[index] = output_val;
+  }
+}
+
+
+template <typename T>
+__device__ void bilinear_interpolate_gradient(
+    const int height, const int width,
+    T y, T x,
+    T & w1, T & w2, T & w3, T & w4,
+    int & x_low, int & x_high, int & y_low, int & y_high,
+    const int index /* index for debug only*/) {
+
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    //empty
+    w1 = w2 = w3 = w4 = 0.;
+    x_low = x_high = y_low = y_high = -1;
+    return;
+  }
+
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;
+
+  y_low = (int) y;
+  x_low = (int) x;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T) y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T) x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // reference in forward
+  // T v1 = bottom_data[y_low * width + x_low];
+  // T v2 = bottom_data[y_low * width + x_high];
+  // T v3 = bottom_data[y_high * width + x_low];
+  // T v4 = bottom_data[y_high * width + x_high];
+  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  return;
+}
+
+template <typename T>
+__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
+    const int num_rois, const T spatial_scale,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const int sampling_ratio,
+    T* bottom_diff,
+    const T* bottom_rois) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not using rounding; this implementation detail is critical
+    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
+    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
+    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
+    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
+    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+    int top_offset    = (n * channels + c) * pooled_height * pooled_width;
+    const T* offset_top_diff = top_diff + top_offset;
+    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
+    {
+      const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix ++)
+      {
+        const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(height, width, y, x,
+            w1, w2, w3, w4,
+            x_low, x_high, y_low, y_high,
+            index);
+
+        T g1 = top_diff_this_bin * w1 / count;
+        T g2 = top_diff_this_bin * w2 / count;
+        T g3 = top_diff_this_bin * w3 / count;
+        T g4 = top_diff_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
+        {
+          atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
+          atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
+          atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
+          atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  } // CUDA_1D_KERNEL_LOOP
+} // RoIAlignBackward
+
+
+at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
+                                 const at::Tensor& rois,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int sampling_ratio) {
+  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
+  dim3 block(512);
+
+  if (output.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return output;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] {
+    RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
+         output_size,
+         input.contiguous().data_ptr<scalar_t>(),
+         spatial_scale,
+         channels,
+         height,
+         width,
+         pooled_height,
+         pooled_width,
+         sampling_ratio,
+         rois.contiguous().data_ptr<scalar_t>(),
+         output.data_ptr<scalar_t>());
+  });
+  THCudaCheck(cudaGetLastError());
+  return output;
+}
+
+// TODO remove the dependency on input and use instead its sizes -> save memory
+at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
+                                  const at::Tensor& rois,
+                                  const float spatial_scale,
+                                  const int pooled_height,
+                                  const int pooled_width,
+                                  const int batch_size,
+                                  const int channels,
+                                  const int height,
+                                  const int width,
+                                  const int sampling_ratio) {
+  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  auto num_rois = rois.size(0);
+  auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return grad_input;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIAlign_backward", [&] {
+    RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
+         grad.numel(),
+         grad.contiguous().data_ptr<scalar_t>(),
+         num_rois,
+         spatial_scale,
+         channels,
+         height,
+         width,
+         pooled_height,
+         pooled_width,
+         sampling_ratio,
+         grad_input.data_ptr<scalar_t>(),
+         rois.contiguous().data_ptr<scalar_t>());
+  });
+  THCudaCheck(cudaGetLastError());
+  return grad_input;
+}
diff --git a/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..60fc9fbc55956304c7ff6b48cbf3c086029b8354
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu
@@ -0,0 +1,202 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+
+// TODO make it in a common file
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+
+template <typename T>
+__global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
+    const T spatial_scale, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const T* bottom_rois, T* top_data, int* argmax_data) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+    int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+    int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+    int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+    int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+    // Force malformed ROIs to be 1x1
+    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
+    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
+    T bin_size_h = static_cast<T>(roi_height)
+                       / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width)
+                       / static_cast<T>(pooled_width);
+
+    int hstart = static_cast<int>(floor(static_cast<T>(ph)
+                                        * bin_size_h));
+    int wstart = static_cast<int>(floor(static_cast<T>(pw)
+                                        * bin_size_w));
+    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1)
+                                     * bin_size_h));
+    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1)
+                                     * bin_size_w));
+
+    // Add roi offsets and clip to input boundaries
+    hstart = min(max(hstart + roi_start_h, 0), height);
+    hend = min(max(hend + roi_start_h, 0), height);
+    wstart = min(max(wstart + roi_start_w, 0), width);
+    wend = min(max(wend + roi_start_w, 0), width);
+    bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+    // Define an empty pooling region to be zero
+    T maxval = is_empty ? 0 : -FLT_MAX;
+    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+    int maxidx = -1;
+    const T* offset_bottom_data =
+        bottom_data + (roi_batch_ind * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        int bottom_index = h * width + w;
+        if (offset_bottom_data[bottom_index] > maxval) {
+          maxval = offset_bottom_data[bottom_index];
+          maxidx = bottom_index;
+        }
+      }
+    }
+    top_data[index] = maxval;
+    argmax_data[index] = maxidx;
+  }
+}
+
+template <typename T>
+__global__ void RoIPoolFBackward(const int nthreads, const T* top_diff,
+    const int* argmax_data, const int num_rois, const T spatial_scale,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width, T* bottom_diff,
+    const T* bottom_rois) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+    int bottom_offset = (roi_batch_ind * channels + c) * height * width;
+    int top_offset    = (n * channels + c) * pooled_height * pooled_width;
+    const T* offset_top_diff = top_diff + top_offset;
+    T* offset_bottom_diff = bottom_diff + bottom_offset;
+    const int* offset_argmax_data = argmax_data + top_offset;
+
+    int argmax = offset_argmax_data[ph * pooled_width + pw];
+    if (argmax != -1) {
+      atomicAdd(
+          offset_bottom_diff + argmax,
+          static_cast<T>(offset_top_diff[ph * pooled_width + pw]));
+
+    }
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
+                                const at::Tensor& rois,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width) {
+  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+  auto argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
+  dim3 block(512);
+
+  if (output.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return std::make_tuple(output, argmax);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIPool_forward", [&] {
+    RoIPoolFForward<scalar_t><<<grid, block, 0, stream>>>(
+         output_size,
+         input.contiguous().data_ptr<scalar_t>(),
+         spatial_scale,
+         channels,
+         height,
+         width,
+         pooled_height,
+         pooled_width,
+         rois.contiguous().data_ptr<scalar_t>(),
+         output.data_ptr<scalar_t>(),
+         argmax.data_ptr<int>());
+  });
+  THCudaCheck(cudaGetLastError());
+  return std::make_tuple(output, argmax);
+}
+
+// TODO remove the dependency on input and use instead its sizes -> save memory
+at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
+                                 const at::Tensor& input,
+                                 const at::Tensor& rois,
+                                 const at::Tensor& argmax,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int batch_size,
+                                 const int channels,
+                                 const int height,
+                                 const int width) {
+  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+  // TODO add more checks
+
+  auto num_rois = rois.size(0);
+  auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return grad_input;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIPool_backward", [&] {
+    RoIPoolFBackward<scalar_t><<<grid, block, 0, stream>>>(
+         grad.numel(),
+         grad.contiguous().data_ptr<scalar_t>(),
+         argmax.data_ptr<int>(),
+         num_rois,
+         spatial_scale,
+         channels,
+         height,
+         width,
+         pooled_height,
+         pooled_width,
+         grad_input.data_ptr<scalar_t>(),
+         rois.contiguous().data_ptr<scalar_t>());
+  });
+  THCudaCheck(cudaGetLastError());
+  return grad_input;
+}
diff --git a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8aeceae0f825598cd36ea99add8da613c5e2482a
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
@@ -0,0 +1,188 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This file is modified from  https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu
+// Cheng-Yang Fu
+// cyfu@cs.unc.edu
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <cfloat>
+
+// TODO make it in a common file
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+
+template <typename T>
+__global__ void SigmoidFocalLossForward(const int nthreads, 
+    const T* logits,
+    const int* targets,
+    const int num_classes,
+    const float gamma, 
+    const float alpha,
+    const int num, 
+    T* losses) {
+  CUDA_1D_KERNEL_LOOP(i, nthreads) {
+
+    int n = i / num_classes;
+    int d = i % num_classes; // current class[0~79]; 
+    int t = targets[n]; // target class [1~80];
+
+    // Decide it is positive or negative case. 
+    T c1 = (t == (d+1)); 
+    T c2 = (t>=0 & t != (d+1));
+
+    T zn = (1.0 - alpha);
+    T zp = (alpha);
+
+    // p = 1. / 1. + expf(-x); p = sigmoid(x)
+    T  p = 1. / (1. + expf(-logits[i]));
+
+    // (1-p)**gamma * log(p) where
+    T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
+
+    // p**gamma * log(1-p)
+    T term2 = powf(p, gamma) *
+            (-1. * logits[i] * (logits[i] >= 0) -   
+             logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0))));
+
+    losses[i] = 0.0;
+    losses[i] += -c1 * term1 * zp;
+    losses[i] += -c2 * term2 * zn;
+
+  } // CUDA_1D_KERNEL_LOOP
+} // SigmoidFocalLossForward
+
+
+template <typename T>
+__global__ void SigmoidFocalLossBackward(const int nthreads,
+                const T* logits,
+                const int* targets,
+                const T* d_losses,
+                const int num_classes,
+                const float gamma,
+                const float alpha,
+                const int num,
+                T* d_logits) {
+  CUDA_1D_KERNEL_LOOP(i, nthreads) {
+
+    int n = i / num_classes;
+    int d = i % num_classes; // current class[0~79]; 
+    int t = targets[n]; // target class [1~80], 0 is background;
+
+    // Decide it is positive or negative case. 
+    T c1 = (t == (d+1));
+    T c2 = (t>=0 & t != (d+1));
+
+    T zn = (1.0 - alpha);
+    T zp = (alpha);
+    // p = 1. / 1. + expf(-x); p = sigmoid(x)
+    T  p = 1. / (1. + expf(-logits[i]));
+
+    // (1-p)**g * (1 - p - g*p*log(p)
+    T term1 = powf((1. - p), gamma) *
+                      (1. - p - (p * gamma * logf(max(p, FLT_MIN))));
+
+    // (p**g) * (g*(1-p)*log(1-p) - p)
+    T term2 = powf(p, gamma) *
+                  ((-1. * logits[i] * (logits[i] >= 0) -
+                      logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) *
+                      (1. - p) * gamma - p);
+    d_logits[i] = 0.0;
+    d_logits[i] += -c1 * term1 * zp;
+    d_logits[i] += -c2 * term2 * zn;
+    d_logits[i] = d_logits[i] * d_losses[i];
+
+  } // CUDA_1D_KERNEL_LOOP
+} // SigmoidFocalLossBackward
+
+
+at::Tensor SigmoidFocalLoss_forward_cuda(
+		const at::Tensor& logits,
+                const at::Tensor& targets,
+		const int num_classes, 
+		const float gamma, 
+		const float alpha) {
+  AT_ASSERTM(logits.device().is_cuda(), "logits must be a CUDA tensor");
+  AT_ASSERTM(targets.device().is_cuda(), "targets must be a CUDA tensor");
+  AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
+
+  const int num_samples = logits.size(0);
+	
+  auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
+  auto losses_size = num_samples * logits.size(1);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L));
+  dim3 block(512);
+
+  if (losses.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return losses;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_forward", [&] {
+    SigmoidFocalLossForward<scalar_t><<<grid, block, 0, stream>>>(
+         losses_size,
+         logits.contiguous().data_ptr<scalar_t>(),
+	 targets.contiguous().data_ptr<int>(),
+         num_classes,
+	 gamma,
+	 alpha,
+	 num_samples,
+         losses.data_ptr<scalar_t>());
+  });
+  THCudaCheck(cudaGetLastError());
+  return losses;   
+}	
+
+
+at::Tensor SigmoidFocalLoss_backward_cuda(
+		const at::Tensor& logits,
+                const at::Tensor& targets,
+		const at::Tensor& d_losses,
+		const int num_classes, 
+		const float gamma, 
+		const float alpha) {
+  AT_ASSERTM(logits.device().is_cuda(), "logits must be a CUDA tensor");
+  AT_ASSERTM(targets.device().is_cuda(), "targets must be a CUDA tensor");
+  AT_ASSERTM(d_losses.device().is_cuda(), "d_losses must be a CUDA tensor");
+
+  AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
+
+  const int num_samples = logits.size(0);
+  AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");
+	
+  auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
+  auto d_logits_size = num_samples * logits.size(1);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L));
+  dim3 block(512);
+
+  if (d_logits.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return d_logits;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_backward", [&] {
+    SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, stream>>>(
+         d_logits_size,
+         logits.contiguous().data_ptr<scalar_t>(),
+	 targets.contiguous().data_ptr<int>(),
+	 d_losses.contiguous().data_ptr<scalar_t>(),
+         num_classes,
+	 gamma,
+	 alpha,
+	 num_samples,
+         d_logits.data_ptr<scalar_t>());
+  });
+
+  THCudaCheck(cudaGetLastError());
+  return d_logits;   
+}	
+
diff --git a/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2cdf8d61957e50d452dd230c97b5754dacd2fa0e
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu
@@ -0,0 +1,691 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <vector>
+#include <iostream>
+#include <cmath>
+
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor grad_im);
+
+void deformable_col2im_coord(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const int channels, const int height,
+    const int width, const int ksize_h, const int ksize_w, const int pad_h,
+    const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+    const at::Tensor data_im, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kenerl_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+    const at::Tensor data_col, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kenerl_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im,
+    const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w, const int dilation_h,
+    const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+    at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+                 int padW, int dilationH, int dilationW, int group,
+                 int deformable_group) 
+{
+  TORCH_CHECK(weight.ndimension() == 4,
+           "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+           "but got: %s",
+           weight.ndimension());
+
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  TORCH_CHECK(kW > 0 && kH > 0,
+           "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+           kW);
+
+  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+           "kernel size should be consistent with weight, ",
+           "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+           kW, weight.size(2), weight.size(3));
+
+  TORCH_CHECK(dW > 0 && dH > 0,
+           "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+  TORCH_CHECK(
+      dilationW > 0 && dilationH > 0,
+      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+      dilationH, dilationW);
+
+  int ndim = input.ndimension();
+  int dimf = 0;
+  int dimh = 1;
+  int dimw = 2;
+
+  if (ndim == 4) {
+    dimf++;
+    dimh++;
+    dimw++;
+  }
+
+  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+           ndim);
+
+  long nInputPlane = weight.size(1) * group;
+  long inputHeight = input.size(dimh);
+  long inputWidth = input.size(dimw);
+  long nOutputPlane = weight.size(0);
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+  TORCH_CHECK(nInputPlane % deformable_group == 0,
+           "input channels must divide deformable group size");
+
+  if (outputWidth < 1 || outputHeight < 1)
+    AT_ERROR(
+        "Given input size: (%ld x %ld x %ld). "
+        "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+        outputWidth);
+
+  TORCH_CHECK(input.size(1) == nInputPlane,
+           "invalid number of input planes, expected: %d, but got: %d",
+           nInputPlane, input.size(1));
+
+  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+           "input image is smaller than kernel");
+
+  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+           "invalid spatial size of offset, expected height: %d width: %d, but "
+           "got height: %d width: %d",
+           outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+           "invalid number of channels of offset");
+
+  if (gradOutput != NULL) {
+    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+             "invalid number of gradOutput planes, expected: %d, but got: %d",
+             nOutputPlane, gradOutput->size(dimf));
+
+    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+              gradOutput->size(dimw) == outputWidth),
+             "invalid size of gradOutput, expected height: %d width: %d , but "
+             "got height: %d width: %d",
+             outputHeight, outputWidth, gradOutput->size(dimh),
+             gradOutput->size(dimw));
+  }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step) 
+{
+  // todo: resize columns to include im2col: done
+  // todo: add im2col_step as input
+  // todo: add new output buffer and transpose it to output (or directly
+  // transpose output) todo: possibly change data indexing because of
+  // parallel_imgs
+
+  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+              dilationH, dilationW, group, deformable_group);
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  weight = weight.contiguous();
+
+  int batch = 1;
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input.unsqueeze_(0);
+    offset.unsqueeze_(0);
+  }
+
+  // todo: assert batchsize dividable by im2col_step
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = weight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+                        outputHeight, outputWidth});
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+    ones = at::ones({outputHeight, outputWidth}, input.options());
+  }
+
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  at::Tensor output_buffer =
+      at::zeros({batchSize / im2col_step, nOutputPlane,
+                 im2col_step * outputHeight, outputWidth},
+                output.options());
+
+  output_buffer = output_buffer.view(
+      {output_buffer.size(0), group, output_buffer.size(1) / group,
+       output_buffer.size(2), output_buffer.size(3)});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, columns);
+
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      output_buffer[elt][g] = output_buffer[elt][g]
+                                  .flatten(1)
+                                  .addmm_(weight[g].flatten(1), columns[g])
+                                  .view_as(output_buffer[elt][g]);
+    }
+  }
+
+  output_buffer = output_buffer.view(
+      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+       output_buffer.size(3), output_buffer.size(4)});
+
+  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+                                      im2col_step, outputHeight, outputWidth});
+  output_buffer.transpose_(1, 2);
+  output.copy_(output_buffer);
+  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    output = output.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+  }
+
+  return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+                                    at::Tensor gradOutput, at::Tensor gradInput,
+                                    at::Tensor gradOffset, at::Tensor weight,
+                                    at::Tensor columns, int kW, int kH, int dW,
+                                    int dH, int padW, int padH, int dilationW,
+                                    int dilationH, int group,
+                                    int deformable_group, int im2col_step) 
+{
+  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+              dilationH, dilationW, group, deformable_group);
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  gradOutput = gradOutput.contiguous();
+  weight = weight.contiguous();
+
+  int batch = 1;
+
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input = input.view({1, input.size(0), input.size(1), input.size(2)});
+    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+    gradOutput = gradOutput.view(
+        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+  }
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = weight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  // change order of grad output
+  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                nOutputPlane, outputHeight, outputWidth});
+  gradOutput.transpose_(1, 2);
+
+  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                              inputHeight, inputWidth});
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+                                deformable_group * 2 * kH * kW, outputHeight,
+                                outputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    // divide into groups
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+    gradOutput = gradOutput.view(
+        {gradOutput.size(0), group, gradOutput.size(1) / group,
+         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+    for (int g = 0; g < group; g++) {
+      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradOutput = gradOutput.view(
+        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+                            dilationH, dilationW, im2col_step, deformable_group,
+                            gradOffset[elt]);
+
+    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, gradInput[elt]);
+  }
+
+  gradOutput.transpose_(1, 2);
+  gradOutput =
+      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  gradOffset = gradOffset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+    gradOffset =
+        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+  }
+
+  return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+    at::Tensor gradWeight,  // at::Tensor gradBias,
+    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+    int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, float scale, int im2col_step) 
+{
+  // todo: transpose and reshape outGrad
+  // todo: reshape columns
+  // todo: add im2col_step as input
+
+  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+              padW, dilationH, dilationW, group, deformable_group);
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  gradOutput = gradOutput.contiguous();
+
+  int batch = 1;
+
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input = input.view(
+        at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+    gradOutput = gradOutput.view(
+        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+  }
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = gradWeight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                nOutputPlane, outputHeight, outputWidth});
+  gradOutput.transpose_(1, 2);
+
+  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+  gradOutputBuffer =
+      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+                             outputHeight, outputWidth});
+  gradOutputBuffer.copy_(gradOutput);
+  gradOutputBuffer =
+      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+                             im2col_step * outputHeight, outputWidth});
+
+  gradOutput.transpose_(1, 2);
+  gradOutput =
+      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, columns);
+
+    // divide into group
+    gradOutputBuffer = gradOutputBuffer.view(
+        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    gradWeight =
+        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+                         gradWeight.size(2), gradWeight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      gradWeight[g] = gradWeight[g]
+                          .flatten(1)
+                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
+                                  columns[g].transpose(1, 0), 1.0, scale)
+                          .view_as(gradWeight[g]);
+    }
+    gradOutputBuffer = gradOutputBuffer.view(
+        {gradOutputBuffer.size(0),
+         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+                                  gradWeight.size(2), gradWeight.size(3),
+                                  gradWeight.size(4)});
+  }
+
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+  }
+
+  return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int group, const int deformable_group,
+    const bool with_bias) 
+{
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+
+  const int channels_out = weight.size(0);
+  const int channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+             kernel_h_, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  // resize output
+  output = output.view({batch, channels_out, height_out, width_out}).zero_();
+  // resize temporary columns
+  columns =
+      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+                input.options());
+
+  output = output.view({output.size(0), group, output.size(1) / group,
+                        output.size(2), output.size(3)});
+
+  for (int b = 0; b < batch; b++) {
+    modulated_deformable_im2col_cuda(
+        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, columns);
+
+    // divide into group
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+    for (int g = 0; g < group; g++) {
+      output[b][g] = output[b][g]
+                         .flatten(1)
+                         .addmm_(weight[g].flatten(1), columns[g])
+                         .view_as(output[b][g]);
+    }
+
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+  }
+
+  output = output.view({output.size(0), output.size(1) * output.size(2),
+                        output.size(3), output.size(4)});
+
+  if (with_bias) {
+    output += bias.view({1, bias.size(0), 1, 1});
+  }
+}
+
+void modulated_deform_conv_cuda_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias) 
+{
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+
+  const int channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+             kernel_h_, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  grad_input = grad_input.view({batch, channels, height, width});
+  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+                      input.options());
+
+  grad_output =
+      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+                        grad_output.size(2), grad_output.size(3)});
+
+  for (int b = 0; b < batch; b++) {
+    // divide int group
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+
+    // gradient w.r.t. input coordinate data
+    modulated_deformable_col2im_coord_cuda(
+        columns, input[b], offset[b], mask[b], 1, channels, height, width,
+        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+        grad_mask[b]);
+    // gradient w.r.t. input data
+    modulated_deformable_col2im_cuda(
+        columns, offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+    // gradient w.r.t. weight, dWeight should accumulate across the batch and
+    // group
+    modulated_deformable_im2col_cuda(
+        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, columns);
+
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+                                    grad_weight.size(1), grad_weight.size(2),
+                                    grad_weight.size(3)});
+    if (with_bias)
+      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+    for (int g = 0; g < group; g++) {
+      grad_weight[g] =
+          grad_weight[g]
+              .flatten(1)
+              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+              .view_as(grad_weight[g]);
+      if (with_bias) {
+        grad_bias[g] =
+            grad_bias[g]
+                .view({-1, 1})
+                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+                .view(-1);
+      }
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+                                    grad_weight.size(2), grad_weight.size(3),
+                                    grad_weight.size(4)});
+    if (with_bias)
+      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+  }
+  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+                                  grad_output.size(2), grad_output.size(3),
+                                  grad_output.size(4)});
+}
diff --git a/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ee15810103a4edaf213abdb222a70249d622c0f9
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu
@@ -0,0 +1,874 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+inline int GET_BLOCKS(const int N)
+{
+  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+/*
+const int CUDA_NUM_THREADS = 1024;
+
+inline int GET_BLOCKS(const int N)
+{
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}*/
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                               const int height, const int width, scalar_t h, scalar_t w)
+{
+
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                        const int h, const int w, const int height, const int width)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                          const int height, const int width, const scalar_t *im_data,
+                                          const int data_width, const int bp_dir)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+
+  return weight;
+}
+
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+                                             const int height, const int width, const int kernel_h, const int kernel_w,
+                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+                                             const int batch_size, const int num_channels, const int deformable_group,
+                                             const int height_col, const int width_col,
+                                             scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    // index index of output matrix
+    const int w_col = index % width_col;
+    const int h_col = (index / width_col) % height_col;
+    const int b_col = (index / width_col / height_col) % batch_size;
+    const int c_im = (index / width_col / height_col) / batch_size;
+    const int c_col = c_im * kernel_h * kernel_w;
+
+    // compute deformable group index
+    const int deformable_group_index = c_im / channel_per_deformable_group;
+
+    const int h_in = h_col * stride_h - pad_h;
+    const int w_in = w_col * stride_w - pad_w;
+    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+    for (int i = 0; i < kernel_h; ++i)
+    {
+      for (int j = 0; j < kernel_w; ++j)
+      {
+        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+        {
+          //const scalar_t map_h = i * dilation_h + offset_h;
+          //const scalar_t map_w = j * dilation_w + offset_w;
+          //const int cur_height = height - h_in;
+          //const int cur_width = width - w_in;
+          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+        }
+        *data_col_ptr = val;
+        data_col_ptr += batch_size * height_col * width_col;
+      }
+    }
+  }
+}
+
+void deformable_im2col(
+    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h, const int ksize_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor data_col)
+{
+  // num_axes should be smaller than block size
+  // todo: check parallel_imgs is correctly passed in
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+            height_col, width_col, data_col_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+    const int n, const scalar_t *data_col, const scalar_t *data_offset,
+    const int channels, const int height, const int width,
+    const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int channel_per_deformable_group,
+    const int batch_size, const int deformable_group,
+    const int height_col, const int width_col,
+    scalar_t *grad_im)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / channel_per_deformable_group;
+
+    int w_out = index % width_col;
+    int h_out = (index / width_col) % height_col;
+    int b = (index / width_col / height_col) % batch_size;
+    int w_in = w_out * stride_w - pad_w;
+    int h_in = h_out * stride_h - pad_h;
+
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+                                                        2 * kernel_h * kernel_w * height_col * width_col;
+    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+    const scalar_t cur_top_grad = data_col[index];
+    const int cur_h = (int)cur_inv_h_data;
+    const int cur_w = (int)cur_inv_w_data;
+    for (int dy = -2; dy <= 2; dy++)
+    {
+      for (int dx = -2; dx <= 2; dx++)
+      {
+        if (cur_h + dy >= 0 && cur_h + dy < height &&
+            cur_w + dx >= 0 && cur_w + dx < width &&
+            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+        {
+          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+        }
+      }
+    }
+  }
+}
+
+void deformable_col2im(
+    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h,
+    const int ksize_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int parallel_imgs, const int deformable_group,
+    at::Tensor grad_im)
+{
+
+  // todo: make sure parallel_imgs is passed in correctly
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+            ksize_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+                                                   const scalar_t *data_im, const scalar_t *data_offset,
+                                                   const int channels, const int height, const int width,
+                                                   const int kernel_h, const int kernel_w,
+                                                   const int pad_h, const int pad_w,
+                                                   const int stride_h, const int stride_w,
+                                                   const int dilation_h, const int dilation_w,
+                                                   const int channel_per_deformable_group,
+                                                   const int batch_size, const int offset_channels, const int deformable_group,
+                                                   const int height_col, const int width_col, scalar_t *grad_offset)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    scalar_t val = 0;
+    int w = index % width_col;
+    int h = (index / width_col) % height_col;
+    int c = (index / width_col / height_col) % offset_channels;
+    int b = (index / width_col / height_col) / offset_channels;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+    const int col_step = kernel_h * kernel_w;
+    int cnt = 0;
+    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+                                                  batch_size * width_col * height_col;
+    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+                                                        kernel_h * kernel_w * height_col * width_col;
+
+    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+    {
+      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+      const int bp_dir = offset_c % 2;
+
+      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+      int w_out = col_pos % width_col;
+      int h_out = (col_pos / width_col) % height_col;
+      int w_in = w_out * stride_w - pad_w;
+      int h_in = h_out * stride_h - pad_h;
+      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+      scalar_t inv_h = h_in + i * dilation_h + offset_h;
+      scalar_t inv_w = w_in + j * dilation_w + offset_w;
+      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+      {
+        inv_h = inv_w = -2;
+      }
+      const scalar_t weight = get_coordinate_weight(
+          inv_h, inv_w,
+          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+      val += weight * data_col_ptr[col_pos];
+      cnt += 1;
+    }
+
+    grad_offset[index] = val;
+  }
+}
+
+void deformable_col2im_coord(
+    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+    const int channels, const int height, const int width, const int ksize_h,
+    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+
+        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+            height_col, width_col, grad_offset_);
+      }));
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                         const int height, const int width, scalar_t h, scalar_t w)
+{
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                             const int h, const int w, const int height, const int width)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                               const int height, const int width, const scalar_t *im_data,
+                                               const int data_width, const int bp_dir)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+
+  return weight;
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+                                                       const int height, const int width, const int kernel_h, const int kernel_w,
+                                                       const int pad_h, const int pad_w,
+                                                       const int stride_h, const int stride_w,
+                                                       const int dilation_h, const int dilation_w,
+                                                       const int channel_per_deformable_group,
+                                                       const int batch_size, const int num_channels, const int deformable_group,
+                                                       const int height_col, const int width_col,
+                                                       scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    // index index of output matrix
+    const int w_col = index % width_col;
+    const int h_col = (index / width_col) % height_col;
+    const int b_col = (index / width_col / height_col) % batch_size;
+    const int c_im = (index / width_col / height_col) / batch_size;
+    const int c_col = c_im * kernel_h * kernel_w;
+
+    // compute deformable group index
+    const int deformable_group_index = c_im / channel_per_deformable_group;
+
+    const int h_in = h_col * stride_h - pad_h;
+    const int w_in = w_col * stride_w - pad_w;
+
+    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+    const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+    for (int i = 0; i < kernel_h; ++i)
+    {
+      for (int j = 0; j < kernel_w; ++j)
+      {
+        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+        {
+          //const float map_h = i * dilation_h + offset_h;
+          //const float map_w = j * dilation_w + offset_w;
+          //const int cur_height = height - h_in;
+          //const int cur_width = width - w_in;
+          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+        }
+        *data_col_ptr = val * mask;
+        data_col_ptr += batch_size * height_col * width_col;
+        //data_col_ptr += height_col * width_col;
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+                                                       const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+                                                       const int channels, const int height, const int width,
+                                                       const int kernel_h, const int kernel_w,
+                                                       const int pad_h, const int pad_w,
+                                                       const int stride_h, const int stride_w,
+                                                       const int dilation_h, const int dilation_w,
+                                                       const int channel_per_deformable_group,
+                                                       const int batch_size, const int deformable_group,
+                                                       const int height_col, const int width_col,
+                                                       scalar_t *grad_im)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / channel_per_deformable_group;
+
+    int w_out = index % width_col;
+    int h_out = (index / width_col) % height_col;
+    int b = (index / width_col / height_col) % batch_size;
+    int w_in = w_out * stride_w - pad_w;
+    int h_in = h_out * stride_h - pad_h;
+
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+    const scalar_t cur_top_grad = data_col[index] * mask;
+    const int cur_h = (int)cur_inv_h_data;
+    const int cur_w = (int)cur_inv_w_data;
+    for (int dy = -2; dy <= 2; dy++)
+    {
+      for (int dx = -2; dx <= 2; dx++)
+      {
+        if (cur_h + dy >= 0 && cur_h + dy < height &&
+            cur_w + dx >= 0 && cur_w + dx < width &&
+            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+        {
+          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+                                                             const scalar_t *data_col, const scalar_t *data_im,
+                                                             const scalar_t *data_offset, const scalar_t *data_mask,
+                                                             const int channels, const int height, const int width,
+                                                             const int kernel_h, const int kernel_w,
+                                                             const int pad_h, const int pad_w,
+                                                             const int stride_h, const int stride_w,
+                                                             const int dilation_h, const int dilation_w,
+                                                             const int channel_per_deformable_group,
+                                                             const int batch_size, const int offset_channels, const int deformable_group,
+                                                             const int height_col, const int width_col,
+                                                             scalar_t *grad_offset, scalar_t *grad_mask)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    scalar_t val = 0, mval = 0;
+    int w = index % width_col;
+    int h = (index / width_col) % height_col;
+    int c = (index / width_col / height_col) % offset_channels;
+    int b = (index / width_col / height_col) / offset_channels;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+    const int col_step = kernel_h * kernel_w;
+    int cnt = 0;
+    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+    {
+      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+      const int bp_dir = offset_c % 2;
+
+      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+      int w_out = col_pos % width_col;
+      int h_out = (col_pos / width_col) % height_col;
+      int w_in = w_out * stride_w - pad_w;
+      int h_in = h_out * stride_h - pad_h;
+      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+      scalar_t inv_h = h_in + i * dilation_h + offset_h;
+      scalar_t inv_w = w_in + j * dilation_w + offset_w;
+      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+      {
+        inv_h = inv_w = -2;
+      }
+      else
+      {
+        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+      }
+      const scalar_t weight = dmcn_get_coordinate_weight(
+          inv_h, inv_w,
+          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+      val += weight * data_col_ptr[col_pos] * mask;
+      cnt += 1;
+    }
+    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+    grad_offset[index] = val;
+    if (offset_c % 2 == 0)
+      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+  }
+}
+
+void modulated_deformable_im2col_cuda(
+    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group, at::Tensor data_col)
+{
+  // num_axes should be smaller than block size
+  const int channel_per_deformable_group = channels / deformable_group;
+  const int num_kernels = channels * batch_size * height_col * width_col;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, channels, deformable_group, height_col, width_col, data_col_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_cuda(
+    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group, at::Tensor grad_im)
+{
+
+  const int channel_per_deformable_group = channels / deformable_group;
+  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+            kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, deformable_group, height_col, width_col, grad_im_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group,
+    at::Tensor grad_offset, at::Tensor grad_mask)
+{
+  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+        scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
+
+        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+            grad_offset_, grad_mask_);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
diff --git a/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bbe22d77b49be70f174ae3f17647b09968358255
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu
@@ -0,0 +1,87 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
+
+// based on
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <vector>
+#include <iostream>
+#include <cmath>
+
+
+void DeformablePSROIPoolForward(
+    const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
+    at::Tensor out, at::Tensor top_count, const int batch, const int channels,
+    const int height, const int width, const int num_bbox,
+    const int channels_trans, const int no_trans, const float spatial_scale,
+    const int output_dim, const int group_size, const int pooled_size,
+    const int part_size, const int sample_per_part, const float trans_std);
+
+void DeformablePSROIPoolBackwardAcc(
+    const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
+    const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
+    at::Tensor trans_grad, const int batch, const int channels,
+    const int height, const int width, const int num_bbox,
+    const int channels_trans, const int no_trans, const float spatial_scale,
+    const int output_dim, const int group_size, const int pooled_size,
+    const int part_size, const int sample_per_part, const float trans_std);
+
+void deform_psroi_pooling_cuda_forward(
+    at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
+    at::Tensor top_count, const int no_trans, const float spatial_scale,
+    const int output_dim, const int group_size, const int pooled_size,
+    const int part_size, const int sample_per_part, const float trans_std) 
+{
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+  const int channels_trans = no_trans ? 2 : trans.size(1);
+
+  const int num_bbox = bbox.size(0);
+  if (num_bbox != out.size(0))
+    AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+             out.size(0), num_bbox);
+
+  DeformablePSROIPoolForward(
+      input, bbox, trans, out, top_count, batch, channels, height, width,
+      num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
+      pooled_size, part_size, sample_per_part, trans_std);
+}
+
+void deform_psroi_pooling_cuda_backward(
+    at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
+    at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
+    const int no_trans, const float spatial_scale, const int output_dim,
+    const int group_size, const int pooled_size, const int part_size,
+    const int sample_per_part, const float trans_std) 
+{
+  TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+  const int channels_trans = no_trans ? 2 : trans.size(1);
+
+  const int num_bbox = bbox.size(0);
+  if (num_bbox != out_grad.size(0))
+    AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+             out_grad.size(0), num_bbox);
+
+  DeformablePSROIPoolBackwardAcc(
+      out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
+      channels, height, width, num_bbox, channels_trans, no_trans,
+      spatial_scale, output_dim, group_size, pooled_size, part_size,
+      sample_per_part, trans_std);
+}
diff --git a/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3f6c4cb22f6ecbae242e21c9530f474e709c6e90
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu
@@ -0,0 +1,365 @@
+/*!
+ * Copyright (c) 2017 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file deformable_psroi_pooling.cu
+ * \brief
+ * \author Yi Li, Guodong Zhang, Jifeng Dai
+*/
+/***************** Adapted by Charles Shang *********************/
+// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
+
+
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <algorithm>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n)                        \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+       i < (n);                                       \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N)
+{
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+template <typename scalar_t>
+__device__ scalar_t bilinear_interp(
+    const scalar_t *data,
+    const scalar_t x,
+    const scalar_t y,
+    const int width,
+    const int height)
+{
+  int x1 = floor(x);
+  int x2 = ceil(x);
+  int y1 = floor(y);
+  int y2 = ceil(y);
+  scalar_t dist_x = (scalar_t)(x - x1);
+  scalar_t dist_y = (scalar_t)(y - y1);
+  scalar_t value11 = data[y1 * width + x1];
+  scalar_t value12 = data[y2 * width + x1];
+  scalar_t value21 = data[y1 * width + x2];
+  scalar_t value22 = data[y2 * width + x2];
+  scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
+  return value;
+}
+
+template <typename scalar_t>
+__global__ void DeformablePSROIPoolForwardKernel(
+    const int count,
+    const scalar_t *bottom_data,
+    const scalar_t spatial_scale,
+    const int channels,
+    const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const scalar_t *bottom_rois, const scalar_t *bottom_trans,
+    const int no_trans,
+    const scalar_t trans_std,
+    const int sample_per_part,
+    const int output_dim,
+    const int group_size,
+    const int part_size,
+    const int num_classes,
+    const int channels_each_class,
+    scalar_t *top_data,
+    scalar_t *top_count)
+{
+  CUDA_KERNEL_LOOP(index, count)
+  {
+    // The output is in order (n, ctop, ph, pw)
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int ctop = (index / pooled_width / pooled_height) % output_dim;
+    int n = index / pooled_width / pooled_height / output_dim;
+
+    // [start, end) interval for spatial sampling
+    const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+    scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+    scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+    scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+    scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+    // Force too small ROIs to be 1x1
+    scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+    scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+    // Compute w and h at bottom
+    scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+    scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+    scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+    scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+    int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+    int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+    int class_id = ctop / channels_each_class;
+    scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+    scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+
+    scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+    wstart += trans_x * roi_width;
+    scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+    hstart += trans_y * roi_height;
+
+    scalar_t sum = 0;
+    int count = 0;
+    int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+    int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+    gw = min(max(gw, 0), group_size - 1);
+    gh = min(max(gh, 0), group_size - 1);
+
+    const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
+    for (int ih = 0; ih < sample_per_part; ih++)
+    {
+      for (int iw = 0; iw < sample_per_part; iw++)
+      {
+        scalar_t w = wstart + iw * sub_bin_size_w;
+        scalar_t h = hstart + ih * sub_bin_size_h;
+        // bilinear interpolation
+        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+        {
+          continue;
+        }
+        w = min(max(w, 0.), width - 1.);
+        h = min(max(h, 0.), height - 1.);
+        int c = (ctop * group_size + gh) * group_size + gw;
+        scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
+        sum += val;
+        count++;
+      }
+    }
+    top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
+    top_count[index] = count;
+  }
+}
+
+template <typename scalar_t>
+__global__ void DeformablePSROIPoolBackwardAccKernel(
+    const int count,
+    const scalar_t *top_diff,
+    const scalar_t *top_count,
+    const int num_rois,
+    const scalar_t spatial_scale,
+    const int channels,
+    const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const int output_dim,
+    scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
+    const scalar_t *bottom_data,
+    const scalar_t *bottom_rois,
+    const scalar_t *bottom_trans,
+    const int no_trans,
+    const scalar_t trans_std,
+    const int sample_per_part,
+    const int group_size,
+    const int part_size,
+    const int num_classes,
+    const int channels_each_class)
+{
+  CUDA_KERNEL_LOOP(index, count)
+  {
+    // The output is in order (n, ctop, ph, pw)
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int ctop = (index / pooled_width / pooled_height) % output_dim;
+    int n = index / pooled_width / pooled_height / output_dim;
+
+    // [start, end) interval for spatial sampling
+    const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+    scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+    scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+    scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+    scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+    // Force too small ROIs to be 1x1
+    scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+    scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+    // Compute w and h at bottom
+    scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+    scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+    scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+    scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+    int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+    int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+    int class_id = ctop / channels_each_class;
+    scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+    scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+
+    scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+    wstart += trans_x * roi_width;
+    scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+    hstart += trans_y * roi_height;
+
+    if (top_count[index] <= 0)
+    {
+      continue;
+    }
+    scalar_t diff_val = top_diff[index] / top_count[index];
+    const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
+    scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
+    int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+    int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+    gw = min(max(gw, 0), group_size - 1);
+    gh = min(max(gh, 0), group_size - 1);
+
+    for (int ih = 0; ih < sample_per_part; ih++)
+    {
+      for (int iw = 0; iw < sample_per_part; iw++)
+      {
+        scalar_t w = wstart + iw * sub_bin_size_w;
+        scalar_t h = hstart + ih * sub_bin_size_h;
+        // bilinear interpolation
+        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+        {
+          continue;
+        }
+        w = min(max(w, 0.), width - 1.);
+        h = min(max(h, 0.), height - 1.);
+        int c = (ctop * group_size + gh) * group_size + gw;
+        // backward on feature
+        int x0 = floor(w);
+        int x1 = ceil(w);
+        int y0 = floor(h);
+        int y1 = ceil(h);
+        scalar_t dist_x = w - x0, dist_y = h - y0;
+        scalar_t q00 = (1 - dist_x) * (1 - dist_y);
+        scalar_t q01 = (1 - dist_x) * dist_y;
+        scalar_t q10 = dist_x * (1 - dist_y);
+        scalar_t q11 = dist_x * dist_y;
+        int bottom_index_base = c * height * width;
+        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
+        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
+        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
+        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
+
+        if (no_trans)
+        {
+          continue;
+        }
+        scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+        scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+        scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+        scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+        scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
+        diff_x *= roi_width;
+        scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
+        diff_y *= roi_height;
+
+        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
+        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
+      }
+    }
+  }
+}
+
+void DeformablePSROIPoolForward(const at::Tensor data,
+                                const at::Tensor bbox,
+                                const at::Tensor trans,
+                                at::Tensor out,
+                                at::Tensor top_count,
+                                const int batch,
+                                const int channels,
+                                const int height,
+                                const int width,
+                                const int num_bbox,
+                                const int channels_trans,
+                                const int no_trans,
+                                const float spatial_scale,
+                                const int output_dim,
+                                const int group_size,
+                                const int pooled_size,
+                                const int part_size,
+                                const int sample_per_part,
+                                const float trans_std)
+{
+  const int pooled_height = pooled_size;
+  const int pooled_width = pooled_size;
+  const int count = num_bbox * output_dim * pooled_height * pooled_width;
+  const int num_classes = no_trans ? 1 : channels_trans / 2;
+  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data.scalar_type(), "deformable_psroi_pool_forward", ([&] {
+        const scalar_t *bottom_data = data.data_ptr<scalar_t>();
+        const scalar_t *bottom_rois = bbox.data_ptr<scalar_t>();
+        const scalar_t *bottom_trans = no_trans ? NULL : trans.data_ptr<scalar_t>();
+        scalar_t *top_data = out.data_ptr<scalar_t>();
+        scalar_t *top_count_data = top_count.data_ptr<scalar_t>();
+
+        DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+            count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
+            bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
+            group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
+                                    const at::Tensor data,
+                                    const at::Tensor bbox,
+                                    const at::Tensor trans,
+                                    const at::Tensor top_count,
+                                    at::Tensor in_grad,
+                                    at::Tensor trans_grad,
+                                    const int batch,
+                                    const int channels,
+                                    const int height,
+                                    const int width,
+                                    const int num_bbox,
+                                    const int channels_trans,
+                                    const int no_trans,
+                                    const float spatial_scale,
+                                    const int output_dim,
+                                    const int group_size,
+                                    const int pooled_size,
+                                    const int part_size,
+                                    const int sample_per_part,
+                                    const float trans_std)
+{
+  // LOG(INFO) << "DeformablePSROIPoolBackward";
+  const int num_rois = num_bbox;
+  const int pooled_height = pooled_size;
+  const int pooled_width = pooled_size;
+  const int count = num_bbox * output_dim * pooled_height * pooled_width;
+  const int num_classes = no_trans ? 1 : channels_trans / 2;
+  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      out_grad.scalar_type(), "deformable_psroi_pool_backward_acc", ([&] {
+        const scalar_t *top_diff = out_grad.data_ptr<scalar_t>();
+        const scalar_t *bottom_data = data.data_ptr<scalar_t>();
+        const scalar_t *bottom_rois = bbox.data_ptr<scalar_t>();
+        const scalar_t *bottom_trans = no_trans ? NULL : trans.data_ptr<scalar_t>();
+        scalar_t *bottom_data_diff = in_grad.data_ptr<scalar_t>();
+        scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data_ptr<scalar_t>();
+        const scalar_t *top_count_data = top_count.data_ptr<scalar_t>();
+
+        DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+            count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
+            pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
+            bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
+            group_size, part_size, num_classes, channels_each_class);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+  }
+}
\ No newline at end of file
diff --git a/maskrcnn_benchmark/csrc/cuda/ml_nms.cu b/maskrcnn_benchmark/csrc/cuda/ml_nms.cu
new file mode 100644
index 0000000000000000000000000000000000000000..cd958a0899a9e3adc69ca053170beb2b34fbd8ef
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/ml_nms.cu
@@ -0,0 +1,136 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <vector>
+#include <iostream>
+
+int const threadsPerBlock = sizeof(unsigned long long) * 8;
+
+__device__ inline float devIoU(float const * const a, float const * const b) {
+  if (a[5] != b[5]) {
+    return 0.0;
+  }
+  float left = max(a[0], b[0]), right = min(a[2], b[2]);
+  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
+  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
+  float interS = width * height;
+  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
+  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
+  return interS / (Sa + Sb - interS);
+}
+
+__global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh,
+                           const float *dev_boxes, unsigned long long *dev_mask) {
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size =
+        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
+  const int col_size =
+        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
+
+  __shared__ float block_boxes[threadsPerBlock * 6];
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 6 + 0] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 0];
+    block_boxes[threadIdx.x * 6 + 1] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 1];
+    block_boxes[threadIdx.x * 6 + 2] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 2];
+    block_boxes[threadIdx.x * 6 + 3] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 3];
+    block_boxes[threadIdx.x * 6 + 4] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 4];
+    block_boxes[threadIdx.x * 6 + 5] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 5];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
+    const float *cur_box = dev_boxes + cur_box_idx * 6;
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (devIoU(cur_box, block_boxes + i * 6) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+    dev_mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+// boxes is a N x 6 tensor
+at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
+  using scalar_t = float;
+  AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor");
+  auto scores = boxes.select(1, 4);
+  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
+  auto boxes_sorted = boxes.index_select(0, order_t);
+
+  int boxes_num = boxes.size(0);
+
+  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+
+  scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();
+
+  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
+
+  unsigned long long* mask_dev = NULL;
+  //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
+  //                      boxes_num * col_blocks * sizeof(unsigned long long)));
+
+  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+
+  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
+              THCCeilDiv(boxes_num, threadsPerBlock));
+  dim3 threads(threadsPerBlock);
+  ml_nms_kernel<<<blocks, threads>>>(boxes_num,
+                                  nms_overlap_thresh,
+                                  boxes_dev,
+                                  mask_dev);
+
+  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
+  THCudaCheck(cudaMemcpy(&mask_host[0],
+                        mask_dev,
+                        sizeof(unsigned long long) * boxes_num * col_blocks,
+                        cudaMemcpyDeviceToHost));
+
+  std::vector<unsigned long long> remv(col_blocks);
+  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
+
+  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
+  int64_t* keep_out = keep.data_ptr<int64_t>();
+
+  int num_to_keep = 0;
+  for (int i = 0; i < boxes_num; i++) {
+    int nblock = i / threadsPerBlock;
+    int inblock = i % threadsPerBlock;
+
+    if (!(remv[nblock] & (1ULL << inblock))) {
+      keep_out[num_to_keep++] = i;
+      unsigned long long *p = &mask_host[0] + i * col_blocks;
+      for (int j = nblock; j < col_blocks; j++) {
+        remv[j] |= p[j];
+      }
+    }
+  }
+
+  THCudaFree(state, mask_dev);
+  // TODO improve this part
+  return std::get<0>(order_t.index({
+                       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
+                         order_t.device(), keep.scalar_type())
+                     }).sort(0, false));
+}
diff --git a/maskrcnn_benchmark/csrc/cuda/nms.cu b/maskrcnn_benchmark/csrc/cuda/nms.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d6221b85fa8f6b40cf498b76d6dbfc3c8438e25e
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/nms.cu
@@ -0,0 +1,131 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <vector>
+#include <iostream>
+
+int const threadsPerBlock = sizeof(unsigned long long) * 8;
+
+__device__ inline float devIoU(float const * const a, float const * const b) {
+  float left = max(a[0], b[0]), right = min(a[2], b[2]);
+  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
+  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
+  float interS = width * height;
+  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
+  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
+  return interS / (Sa + Sb - interS);
+}
+
+__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
+                           const float *dev_boxes, unsigned long long *dev_mask) {
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size =
+        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
+  const int col_size =
+        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
+
+  __shared__ float block_boxes[threadsPerBlock * 5];
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 5 + 0] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
+    block_boxes[threadIdx.x * 5 + 1] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
+    block_boxes[threadIdx.x * 5 + 2] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
+    block_boxes[threadIdx.x * 5 + 3] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
+    block_boxes[threadIdx.x * 5 + 4] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
+    const float *cur_box = dev_boxes + cur_box_idx * 5;
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+    dev_mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+// boxes is a N x 5 tensor
+at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
+  using scalar_t = float;
+  AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor");
+  auto scores = boxes.select(1, 4);
+  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
+  auto boxes_sorted = boxes.index_select(0, order_t);
+
+  int boxes_num = boxes.size(0);
+
+  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+
+  scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();
+
+  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
+
+  unsigned long long* mask_dev = NULL;
+  //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
+  //                      boxes_num * col_blocks * sizeof(unsigned long long)));
+
+  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+
+  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
+              THCCeilDiv(boxes_num, threadsPerBlock));
+  dim3 threads(threadsPerBlock);
+  nms_kernel<<<blocks, threads>>>(boxes_num,
+                                  nms_overlap_thresh,
+                                  boxes_dev,
+                                  mask_dev);
+
+  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
+  THCudaCheck(cudaMemcpy(&mask_host[0],
+                        mask_dev,
+                        sizeof(unsigned long long) * boxes_num * col_blocks,
+                        cudaMemcpyDeviceToHost));
+
+  std::vector<unsigned long long> remv(col_blocks);
+  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
+
+  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
+  int64_t* keep_out = keep.data_ptr<int64_t>();
+
+  int num_to_keep = 0;
+  for (int i = 0; i < boxes_num; i++) {
+    int nblock = i / threadsPerBlock;
+    int inblock = i % threadsPerBlock;
+
+    if (!(remv[nblock] & (1ULL << inblock))) {
+      keep_out[num_to_keep++] = i;
+      unsigned long long *p = &mask_host[0] + i * col_blocks;
+      for (int j = nblock; j < col_blocks; j++) {
+        remv[j] |= p[j];
+      }
+    }
+  }
+
+  THCudaFree(state, mask_dev);
+  // TODO improve this part
+  return std::get<0>(order_t.index({
+                       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
+                         order_t.device(), keep.scalar_type())
+                     }).sort(0, false));
+}
diff --git a/maskrcnn_benchmark/csrc/cuda/vision.h b/maskrcnn_benchmark/csrc/cuda/vision.h
new file mode 100644
index 0000000000000000000000000000000000000000..16a7f644ed5798d1917d32cda0590161b6da8c64
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/vision.h
@@ -0,0 +1,116 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+#include <torch/extension.h>
+
+
+at::Tensor SigmoidFocalLoss_forward_cuda(
+		const at::Tensor& logits,
+                const at::Tensor& targets,
+		const int num_classes, 
+		const float gamma, 
+		const float alpha); 
+
+at::Tensor SigmoidFocalLoss_backward_cuda(
+			     const at::Tensor& logits,
+                             const at::Tensor& targets,
+			     const at::Tensor& d_losses,
+			     const int num_classes,
+			     const float gamma,
+			     const float alpha);
+
+at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
+                                 const at::Tensor& rois,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int sampling_ratio);
+
+at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
+                                  const at::Tensor& rois,
+                                  const float spatial_scale,
+                                  const int pooled_height,
+                                  const int pooled_width,
+                                  const int batch_size,
+                                  const int channels,
+                                  const int height,
+                                  const int width,
+                                  const int sampling_ratio);
+
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
+                                const at::Tensor& rois,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width);
+
+at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
+                                 const at::Tensor& input,
+                                 const at::Tensor& rois,
+                                 const at::Tensor& argmax,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int batch_size,
+                                 const int channels,
+                                 const int height,
+                                 const int width);
+
+at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
+at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step);
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+                                    at::Tensor gradOutput, at::Tensor gradInput,
+                                    at::Tensor gradOffset, at::Tensor weight,
+                                    at::Tensor columns, int kW, int kH, int dW,
+                                    int dH, int padW, int padH, int dilationW,
+                                    int dilationH, int group,
+                                    int deformable_group, int im2col_step);
+
+int deform_conv_backward_parameters_cuda(
+    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+    at::Tensor gradWeight,  // at::Tensor gradBias,
+    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+    int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, float scale, int im2col_step);
+
+void modulated_deform_conv_cuda_forward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int group, const int deformable_group,
+    const bool with_bias);
+
+void modulated_deform_conv_cuda_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias);
+
+void deform_psroi_pooling_cuda_forward(
+    at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
+    at::Tensor top_count, const int no_trans, const float spatial_scale,
+    const int output_dim, const int group_size, const int pooled_size,
+    const int part_size, const int sample_per_part, const float trans_std);
+
+void deform_psroi_pooling_cuda_backward(
+    at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
+    at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
+    const int no_trans, const float spatial_scale, const int output_dim,
+    const int group_size, const int pooled_size, const int part_size,
+    const int sample_per_part, const float trans_std);
+
+
+at::Tensor compute_flow_cuda(const at::Tensor& boxes,
+                             const int height,
+                             const int width);
diff --git a/maskrcnn_benchmark/csrc/deform_conv.h b/maskrcnn_benchmark/csrc/deform_conv.h
new file mode 100644
index 0000000000000000000000000000000000000000..56452c18cb8677ed964ca08c9e6e68b368da39a6
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/deform_conv.h
@@ -0,0 +1,191 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+
+// Interface for Python
+int deform_conv_forward(
+    at::Tensor input, 
+    at::Tensor weight,
+    at::Tensor offset, 
+    at::Tensor output,
+    at::Tensor columns, 
+    at::Tensor ones, 
+    int kW,
+    int kH, 
+    int dW, 
+    int dH, 
+    int padW, 
+    int padH,
+    int dilationW, 
+    int dilationH, 
+    int group,
+    int deformable_group, 
+    int im2col_step)
+{
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_conv_forward_cuda(
+        input, weight, offset, output, columns, ones,
+        kW, kH, dW, dH, padW, padH, dilationW, dilationH,
+        group, deformable_group, im2col_step
+    );
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+
+int deform_conv_backward_input(
+    at::Tensor input, 
+    at::Tensor offset,
+    at::Tensor gradOutput, 
+    at::Tensor gradInput,
+    at::Tensor gradOffset, 
+    at::Tensor weight,
+    at::Tensor columns, 
+    int kW, 
+    int kH, 
+    int dW,
+    int dH, 
+    int padW, 
+    int padH, 
+    int dilationW,
+    int dilationH, 
+    int group,
+    int deformable_group, 
+    int im2col_step)
+{
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_conv_backward_input_cuda(
+        input, offset, gradOutput, gradInput, gradOffset, weight, columns,
+        kW, kH, dW, dH, padW, padH, dilationW, dilationH, 
+        group, deformable_group, im2col_step
+    );
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+
+int deform_conv_backward_parameters(
+    at::Tensor input, 
+    at::Tensor offset, 
+    at::Tensor gradOutput,
+    at::Tensor gradWeight,  // at::Tensor gradBias,
+    at::Tensor columns, 
+    at::Tensor ones, 
+    int kW, 
+    int kH, 
+    int dW, 
+    int dH,
+    int padW, 
+    int padH, 
+    int dilationW, 
+    int dilationH, 
+    int group,
+    int deformable_group, 
+    float scale, 
+    int im2col_step)
+{
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_conv_backward_parameters_cuda(
+        input, offset, gradOutput, gradWeight, columns, ones,
+        kW, kH, dW, dH, padW, padH, dilationW, dilationH,
+        group, deformable_group, scale, im2col_step
+    );
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+
+void modulated_deform_conv_forward(
+    at::Tensor input, 
+    at::Tensor weight, 
+    at::Tensor bias, 
+    at::Tensor ones,
+    at::Tensor offset, 
+    at::Tensor mask, 
+    at::Tensor output, 
+    at::Tensor columns,
+    int kernel_h, 
+    int kernel_w, 
+    const int stride_h, 
+    const int stride_w,
+    const int pad_h, 
+    const int pad_w, 
+    const int dilation_h,
+    const int dilation_w, 
+    const int group, 
+    const int deformable_group,
+    const bool with_bias)
+{
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return modulated_deform_conv_cuda_forward(
+        input, weight, bias, ones, offset, mask, output, columns,
+        kernel_h, kernel_w, stride_h, stride_w, 
+        pad_h, pad_w, dilation_h, dilation_w,
+        group, deformable_group, with_bias
+    );
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+
+void modulated_deform_conv_backward(
+    at::Tensor input, 
+    at::Tensor weight, 
+    at::Tensor bias, 
+    at::Tensor ones,
+    at::Tensor offset, 
+    at::Tensor mask, 
+    at::Tensor columns,
+    at::Tensor grad_input, 
+    at::Tensor grad_weight, 
+    at::Tensor grad_bias,
+    at::Tensor grad_offset, 
+    at::Tensor grad_mask, 
+    at::Tensor grad_output,
+    int kernel_h, 
+    int kernel_w, 
+    int stride_h, 
+    int stride_w, 
+    int pad_h,
+    int pad_w, 
+    int dilation_h, 
+    int dilation_w, 
+    int group, 
+    int deformable_group,
+    const bool with_bias)
+{
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return modulated_deform_conv_cuda_backward(
+        input, weight, bias, ones, offset, mask, columns, 
+        grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output,
+        kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
+        group, deformable_group, with_bias
+    );
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
\ No newline at end of file
diff --git a/maskrcnn_benchmark/csrc/deform_pool.h b/maskrcnn_benchmark/csrc/deform_pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..b3379e205caa43d854447ba896ce5848ccd65c89
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/deform_pool.h
@@ -0,0 +1,70 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+
+// Interface for Python
+void deform_psroi_pooling_forward(
+    at::Tensor input, 
+    at::Tensor bbox, 
+    at::Tensor trans, 
+    at::Tensor out,
+    at::Tensor top_count, 
+    const int no_trans, 
+    const float spatial_scale,
+    const int output_dim, 
+    const int group_size, 
+    const int pooled_size,
+    const int part_size, 
+    const int sample_per_part, 
+    const float trans_std)
+{
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_psroi_pooling_cuda_forward(
+        input, bbox, trans, out, top_count, 
+        no_trans, spatial_scale, output_dim, group_size,
+        pooled_size, part_size, sample_per_part, trans_std
+    );
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+
+void deform_psroi_pooling_backward(
+    at::Tensor out_grad, 
+    at::Tensor input, 
+    at::Tensor bbox, 
+    at::Tensor trans,
+    at::Tensor top_count, 
+    at::Tensor input_grad, 
+    at::Tensor trans_grad,
+    const int no_trans, 
+    const float spatial_scale, 
+    const int output_dim,
+    const int group_size, 
+    const int pooled_size, 
+    const int part_size,
+    const int sample_per_part, 
+    const float trans_std) 
+{
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_psroi_pooling_cuda_backward(
+        out_grad, input, bbox, trans, top_count, input_grad, trans_grad,
+        no_trans, spatial_scale, output_dim, group_size, pooled_size, 
+        part_size, sample_per_part, trans_std
+    );
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
diff --git a/maskrcnn_benchmark/csrc/ml_nms.h b/maskrcnn_benchmark/csrc/ml_nms.h
new file mode 100644
index 0000000000000000000000000000000000000000..bb4370d0576a3280b324ae69257f41789dd2416d
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/ml_nms.h
@@ -0,0 +1,27 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+
+at::Tensor ml_nms(const at::Tensor& dets,
+                  const at::Tensor& scores,
+                  const at::Tensor& labels,
+                  const float threshold) {
+
+  if (dets.device().is_cuda()) {
+#ifdef WITH_CUDA
+    // TODO raise error if not compiled with CUDA
+    if (dets.numel() == 0)
+      return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
+    auto b = at::cat({dets, scores.unsqueeze(1), labels.unsqueeze(1)}, 1);
+    return ml_nms_cuda(b, threshold);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("CPU version not implemented");
+}
diff --git a/maskrcnn_benchmark/csrc/nms.h b/maskrcnn_benchmark/csrc/nms.h
new file mode 100644
index 0000000000000000000000000000000000000000..cb86028949747e215a8f5c74d768ece8937f4f81
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/nms.h
@@ -0,0 +1,45 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+
+at::Tensor nms(const at::Tensor& dets,
+               const at::Tensor& scores,
+               const float threshold) {
+
+  if (dets.device().is_cuda()) {
+#ifdef WITH_CUDA
+    // TODO raise error if not compiled with CUDA
+    if (dets.numel() == 0)
+      return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
+    auto b = at::cat({dets, scores.unsqueeze(1)}, 1);
+    return nms_cuda(b, threshold);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+
+  at::Tensor result = nms_cpu(dets, scores, threshold);
+  return result;
+}
+
+
+std::pair<at::Tensor, at::Tensor> soft_nms(const at::Tensor& dets,
+                                           const at::Tensor& scores,
+                                           const float threshold,
+                                           const float sigma) {
+
+  if (dets.device().is_cuda()) {
+#ifdef WITH_CUDA
+    AT_ERROR("Soft NMS Does Not have GPU support");
+#endif
+  }
+
+  std::pair<at::Tensor, at::Tensor> result = soft_nms_cpu(dets, scores, threshold, sigma);
+
+  return result;
+}
\ No newline at end of file
diff --git a/maskrcnn_benchmark/csrc/vision.cpp b/maskrcnn_benchmark/csrc/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a5bd4751b67aa35f7649dd3f5b733982e38088d1
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/vision.cpp
@@ -0,0 +1,27 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include "nms.h"
+#include "ml_nms.h"
+#include "ROIAlign.h"
+#include "ROIPool.h"
+#include "SigmoidFocalLoss.h"
+#include "deform_conv.h"
+#include "deform_pool.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("nms", &nms, "non-maximum suppression");
+  m.def("ml_nms", &ml_nms, "multi-label non-maximum suppression");
+  m.def("soft_nms", &soft_nms, "soft non-maximum suppression");
+  m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
+  m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
+  m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
+  m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
+  m.def("sigmoid_focalloss_forward", &SigmoidFocalLoss_forward, "SigmoidFocalLoss_forward");
+  m.def("sigmoid_focalloss_backward", &SigmoidFocalLoss_backward, "SigmoidFocalLoss_backward");
+  m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
+  m.def("deform_conv_backward_input", &deform_conv_backward_input, "deform_conv_backward_input");
+  m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, "deform_conv_backward_parameters");
+  m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, "modulated_deform_conv_forward");
+  m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, "modulated_deform_conv_backward");
+  m.def("deform_psroi_pooling_forward", &deform_psroi_pooling_forward, "deform_psroi_pooling_forward");
+  m.def("deform_psroi_pooling_backward", &deform_psroi_pooling_backward, "deform_psroi_pooling_backward");
+}
diff --git a/maskrcnn_benchmark/data/__init__.py b/maskrcnn_benchmark/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0210bc1653fd56b4fcea06e22f185ffaa57e06
--- /dev/null
+++ b/maskrcnn_benchmark/data/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from .build import make_data_loader
diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..14b5973b5642d9d1d99093887a49bda869d0246a
--- /dev/null
+++ b/maskrcnn_benchmark/data/build.py
@@ -0,0 +1,489 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import bisect
+import copy
+import logging
+import os
+
+import torch.utils.data
+import torch.distributed as dist
+from maskrcnn_benchmark.utils.comm import get_world_size
+from maskrcnn_benchmark.utils.imports import import_file
+
+from . import datasets as D
+from . import samplers
+
+from .collate_batch import BatchCollator, BBoxAugCollator
+from .transforms import build_transforms
+
+from transformers import AutoTokenizer
+from .datasets.duplicate_dataset import create_duplicate_dataset
+
+def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True, class_concat=False, extra_args={}):
+    """
+    Arguments:
+        dataset_list (list[str]): Contains the names of the datasets, i.e.,
+            coco_2014_trian, coco_2014_val, etc
+        transforms (callable): transforms to apply to each (image, target) sample
+        dataset_catalog (DatasetCatalog): contains the information on how to
+            construct a dataset.
+        is_train (bool): whether to setup the dataset for training or testing
+    """
+    if not isinstance(dataset_list, (list, tuple)):
+        raise RuntimeError(
+            "dataset_list should be a list of strings, got {}".format(dataset_list)
+        )
+    datasets = []
+    num_category = 1
+    for dataset_id, dataset_name in enumerate(dataset_list, 1):
+        if is_train:
+            dataset_name = dataset_name + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX
+        else:
+            dataset_name = dataset_name + cfg.DATASETS.TEST_DATASETNAME_SUFFIX
+        data = dataset_catalog.get(dataset_name)
+        factory = getattr(D, data["factory"])
+        args = data["args"]
+        # for COCODataset, we want to remove images without annotations
+        # during training
+        if data["factory"] == "COCODataset":
+            args["remove_images_without_annotations"] = is_train
+
+        if data["factory"] == "PascalVOCDataset":
+            args["use_difficult"] = not is_train
+        if data["factory"] in ["VGTSVDataset", "CocoDetectionTSV", "ODTSVDataset"]:
+            args["extra_fields"] = ["class"]
+            if cfg.MODEL.MASK_ON:
+                args["extra_fields"].append("mask")
+
+        if data["factory"] in ["CocoGrounding", "CocoDetectionTSV", "CaptionTSV", "MixedDataset", "FlickrDataset", "RefExpDataset", "GQADataset", "PseudoData", "PhrasecutDetection"]:
+            # args["return_masks"] = False
+            args["return_masks"] = cfg.MODEL.MASK_ON
+            args["return_tokens"] = True
+            args["max_num_labels"] = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM
+            args["max_query_len"] = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
+
+        args["transforms"] = transforms
+        args.update(extra_args)
+
+        if dataset_name == "flickr30k_train":
+            copy = cfg.DATASETS.FLICKR_COPY
+        elif dataset_name in ["mixed_train", "mixed_train_no_coco"]:
+            copy = cfg.DATASETS.MIXED_COPY
+        elif dataset_name == "COCO_odinw_train_8copy_dt_train":
+            copy = cfg.DATASETS.COCO_COPY
+        elif dataset_name == "LVIS_odinw_train_8copy_dt_train":
+            copy = cfg.DATASETS.LVIS_COPY
+        elif dataset_name == "object365_odinw_2copy_dt_train":
+            copy = cfg.DATASETS.OBJECT365_COPY
+        elif dataset_name == "vg_odinw_clipped_8copy_dt_train":
+            copy = cfg.DATASETS.VG_COPY
+        elif dataset_name == "vg_vgoi6_clipped_8copy_dt_train":
+            copy = cfg.DATASETS.VG_COPY
+        elif dataset_name == "imagenetod_train_odinw_2copy_dt":
+            copy = cfg.DATASETS.IN_COPY
+        elif dataset_name == "oi_train_odinw_dt":
+            copy = cfg.DATASETS.OI_COPY
+        elif is_train:
+            copy = cfg.DATASETS.GENERAL_COPY
+        elif not is_train:
+            copy = cfg.DATASETS.GENERAL_COPY_TEST
+        else:
+            copy = -1 # do not ever copy test
+        
+        if copy != -1:
+            new_factory = create_duplicate_dataset(factory)
+            dataset = new_factory(copy=copy, **args)
+        else:
+            # make dataset from factory
+            dataset = factory(**args)
+
+        print(dataset_name, 'has the {} data points'.format(len(dataset)), data["factory"])
+
+        if class_concat:
+            category = list(dataset.contiguous_category_id_to_json_id.values())
+            dataset.contiguous_category_id_to_json_id = {}
+            dataset.json_category_id_to_contiguous_id = {}
+            for id, cat in enumerate(category, start=num_category):
+                dataset.json_category_id_to_contiguous_id[cat] = id
+                dataset.contiguous_category_id_to_json_id[id] = cat
+            num_category += len(category)
+            print("Found {} #category after group {}, concating ...".format(num_category, dataset_id))
+        datasets.append(dataset)
+
+    # for testing, return a list of datasets
+    if not is_train:
+        return datasets
+
+    # for training, concatenate all datasets into a single one
+    dataset = datasets[0]
+    if len(datasets) > 1:
+        dataset = D.ConcatDataset(datasets)
+
+    return [dataset]
+
+
+def build_dataset_by_group(dataset_list, transforms, dataset_catalog, is_train=True, class_by_group=True,
+                           class_concat=False, extra_args={}):
+    """
+    Arguments:
+        dataset_list (list[str]): Contains the names of the datasets, i.e.,
+            coco_2014_trian, coco_2014_val, etc
+        transforms (callable): transforms to apply to each (image, target) sample
+        dataset_catalog (DatasetCatalog): contains the information on how to
+            construct a dataset.
+        is_train (bool): whether to setup the dataset for training or testing
+    """
+    if not isinstance(dataset_list, (list, tuple)):
+        raise RuntimeError(
+            "dataset_list should be a list of strings, got {}".format(dataset_list)
+        )
+
+    num_category = 1
+    grouped_datasets = []
+    for group_id, group in enumerate(dataset_list, 1):
+        datasets = []
+        for dataset_name in group:
+            data = dataset_catalog.get(dataset_name)
+            factory = getattr(D, data["factory"])
+            args = data["args"]
+            # for COCODataset, we want to remove images without annotations
+            # during training
+            if data["factory"] == "COCODataset":
+                args["remove_images_without_annotations"] = is_train
+            if data["factory"] == "PascalVOCDataset":
+                args["use_difficult"] = not is_train
+            args["transforms"] = transforms
+            args.update(extra_args)
+            # make dataset from factory
+            dataset = factory(**args)
+
+            # check if dataset is grouped by task, assume one class per task
+            if class_by_group and data["factory"] != "Background":
+                category = dataset.contiguous_category_id_to_json_id[1]
+                del dataset.contiguous_category_id_to_json_id[1]
+                dataset.json_category_id_to_contiguous_id[category] = group_id
+                dataset.contiguous_category_id_to_json_id[group_id] = category
+
+            datasets.append(dataset)
+
+        if class_concat:
+            for dataset in datasets:
+                category = list(dataset.contiguous_category_id_to_json_id.values())
+                dataset.contiguous_category_id_to_json_id = {}
+                dataset.json_category_id_to_contiguous_id = {}
+                for id, cat in enumerate(category, start=num_category):
+                    dataset.json_category_id_to_contiguous_id[cat] = id
+                    dataset.contiguous_category_id_to_json_id[id] = cat
+            num_category += len(category)
+            print("Found {} #category after group {}, concating ...".format(num_category, group_id))
+
+        if is_train:
+            datasets = D.ConcatDataset(datasets)
+
+        grouped_datasets.append(datasets)
+
+    # for testing, return a list of datasets
+    if not is_train:
+        datasets = [dataset for group in grouped_datasets for dataset in group]
+        return datasets
+    if class_concat:
+        grouped_datasets = D.ConcatDataset(grouped_datasets)
+        return [grouped_datasets]
+
+    # for training, concatenate all datasets into a single one
+    return grouped_datasets
+
+
+def make_data_sampler(dataset, shuffle, distributed, num_replicas=None, rank=None, use_random_seed=True):
+    if distributed:
+        return samplers.DistributedSampler(dataset, shuffle=shuffle, num_replicas=num_replicas, rank=rank,
+                                           use_random=use_random_seed)
+    if shuffle:
+        sampler = torch.utils.data.sampler.RandomSampler(dataset)
+    else:
+        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
+    return sampler
+
+
+def _quantize(x, bins):
+    bins = copy.copy(bins)
+    bins = sorted(bins)
+    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+    return quantized
+
+
+def _compute_aspect_ratios(dataset):
+    aspect_ratios = []
+    for i in range(len(dataset)):
+        img_info = dataset.get_img_info(i)
+        aspect_ratio = float(img_info["height"]) / float(img_info["width"])
+        aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def make_batch_data_sampler(
+        dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0, drop_last=False
+):
+    if aspect_grouping:
+        if not isinstance(aspect_grouping, (list, tuple)):
+            aspect_grouping = [aspect_grouping]
+        aspect_ratios = _compute_aspect_ratios(dataset)
+        group_ids = _quantize(aspect_ratios, aspect_grouping)
+        batch_sampler = samplers.GroupedBatchSampler(
+            sampler, group_ids, images_per_batch, drop_uneven=drop_last
+        )
+    else:
+        batch_sampler = torch.utils.data.sampler.BatchSampler(
+            sampler, images_per_batch, drop_last=drop_last
+        )
+    if num_iters is not None:
+        batch_sampler = samplers.IterationBasedBatchSampler(
+            batch_sampler, num_iters, start_iter
+        )
+    return batch_sampler
+
+def make_data_loader(cfg, is_train=True, is_distributed=False, num_replicas=None, rank=None, start_iter=0):
+    num_gpus = num_replicas or get_world_size()
+
+    if is_train:
+        images_per_batch = cfg.SOLVER.IMS_PER_BATCH
+        assert (
+                images_per_batch % num_gpus == 0
+        ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number "
+        "of GPUs ({}) used.".format(images_per_batch, num_gpus)
+        images_per_gpu = images_per_batch // num_gpus
+        shuffle = True
+        num_iters = cfg.SOLVER.MAX_ITER
+    else:
+        images_per_batch = cfg.TEST.IMS_PER_BATCH
+        assert (
+                images_per_batch % num_gpus == 0
+        ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number "
+        "of GPUs ({}) used.".format(images_per_batch, num_gpus)
+        images_per_gpu = images_per_batch // num_gpus
+        shuffle = False if not is_distributed else True
+        num_iters = None
+        start_iter = 0
+
+    if images_per_gpu > 1:
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            "When using more than one image per GPU you may encounter "
+            "an out-of-memory (OOM) error if your GPU does not have "
+            "sufficient memory. If this happens, you can reduce "
+            "SOLVER.IMS_PER_BATCH (for training) or "
+            "TEST.IMS_PER_BATCH (for inference). For training, you must "
+            "also adjust the learning rate and schedule length according "
+            "to the linear scaling rule. See for example: "
+            "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14"
+        )
+
+    # group images which have similar aspect ratio. In this case, we only
+    # group in two cases: those with width / height > 1, and the other way around,
+    # but the code supports more general grouping strategy
+    aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else []
+
+    paths_catalog = import_file(
+        "maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True
+    )
+
+    DatasetCatalog = paths_catalog.DatasetCatalog
+    if len(cfg.DATASETS.REGISTER) > 0:
+        for new_dataset in cfg.DATASETS.REGISTER:
+            # img_dir = cfg.DATASETS.REGISTER[new_dataset]["img_dir"]
+            # if "ann_file" in cfg.DATASETS.REGISTER[new_dataset]:
+            #     ann_file = cfg.DATASETS.REGISTER[new_dataset]["ann_file"]
+            # else:
+            #     ann_file = None
+            attrs = dict(cfg.DATASETS.REGISTER[new_dataset])
+            if is_train:
+                new_dataset = new_dataset + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX
+            else:
+                new_dataset = new_dataset + cfg.DATASETS.TEST_DATASETNAME_SUFFIX
+            DatasetCatalog.set(new_dataset, attrs)
+
+
+    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
+
+    # Haotian: expand bing dataset
+    if "bing_caption_train" in dataset_list and len(cfg.DATASETS.BING_INDEX_LIST) > 0:
+        dataset_list = list(dataset_list)
+        dataset_list.remove("bing_caption_train")
+        for bing_index in cfg.DATASETS.BING_INDEX_LIST:
+            dataset_list.insert(len(dataset_list), "bing_caption_{}_train".format(bing_index))
+        dataset_list = tuple(dataset_list)
+    
+    if "bing_caption_train_no_coco" in dataset_list and len(cfg.DATASETS.BING_INDEX_LIST) > 0:
+        dataset_list = list(dataset_list)
+        dataset_list.remove("bing_caption_train_no_coco")
+        for bing_index in cfg.DATASETS.BING_INDEX_LIST:
+            dataset_list.insert(len(dataset_list), "bing_caption_{}_train_no_coco".format(bing_index))
+        dataset_list = tuple(dataset_list)
+
+    print("The combined datasets are: {}.".format(dataset_list))
+
+    transforms = None if not is_train and cfg.TEST.USE_MULTISCALE else build_transforms(cfg, is_train)
+
+    extra_args = {}
+    if is_train and cfg.DATASETS.USE_CROWD:
+        extra_args['ignore_crowd'] = False
+    if is_train and cfg.DATASETS.MAX_BOX > 0:
+        extra_args['max_box'] = cfg.DATASETS.MAX_BOX
+    if is_train and cfg.DATASETS.FEW_SHOT>0:
+        extra_args['few_shot'] = cfg.DATASETS.FEW_SHOT
+    if is_train and cfg.DATASETS.SHUFFLE_SEED != 0:
+        extra_args['shuffle_seed'] = cfg.DATASETS.SHUFFLE_SEED
+
+    # od to grounding
+    if is_train and cfg.DATASETS.RANDOM_SAMPLE_NEG > 0:
+        extra_args['random_sample_negative'] = cfg.DATASETS.RANDOM_SAMPLE_NEG
+    if is_train and cfg.DATASETS.ADD_DET_PROMPT:
+        extra_args["add_detection_prompt"] = True
+    if is_train and cfg.DATASETS.USE_OD_AUG:
+        extra_args["use_od_data_aug"] = True
+    if is_train and cfg.DATASETS.DISABLE_SHUFFLE:
+        extra_args["disable_shuffle"] = True
+    if cfg.DATASETS.ONE_HOT:
+        extra_args["one_hot"] = True
+    if is_train and len(cfg.DATASETS.PROMPT_VERSION) > 0:
+        extra_args["prompt_engineer_version"] = cfg.DATASETS.PROMPT_VERSION
+    if is_train and len(cfg.DATASETS.CONTROL_PROB) == 4:
+        extra_args["control_probabilities"] = cfg.DATASETS.CONTROL_PROB
+    if is_train and cfg.DATASETS.DISABLE_CLIP_TO_IMAGE:
+        extra_args["disable_clip_to_image"] =  cfg.DATASETS.DISABLE_CLIP_TO_IMAGE
+    if is_train and cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT:
+        extra_args["no_minus_one_for_one_hot"] = cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT
+    if is_train:
+        extra_args["separation_tokens"] = cfg.DATASETS.SEPARATION_TOKENS
+    # caption
+    if is_train and cfg.DATASETS.CAPTION_MIN_BOX > 0:
+        extra_args["caption_min_box"] = cfg.DATASETS.CAPTION_MIN_BOX
+    if is_train and cfg.DATASETS.REPLACE_CLEAN_LABEL:
+        extra_args["replace_clean_label"] = True
+    if is_train and cfg.DATASETS.FURTHER_SCREEN:
+        extra_args["further_screen"] = True
+    if is_train and cfg.DATASETS.CAPTION_CONF > 0.0:
+        extra_args["caption_conf"] = cfg.DATASETS.CAPTION_CONF
+    if is_train:
+        extra_args["caption_nms"] = cfg.DATASETS.CAPTION_NMS
+    if is_train and cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER > 0:
+        extra_args["pack_random_caption_number"] = cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER
+    if is_train and cfg.DATASETS.INFERENCE_CAPTION:
+        extra_args["inference_caption"] = True
+    if is_train and cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA > 0:
+        extra_args["sample_negative_for_grounding_data"] = cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA
+    if is_train and cfg.DATASETS.RANDOM_PACK_PROB > 0:
+        extra_args["random_pack_prob"] = cfg.DATASETS.RANDOM_PACK_PROB
+    if is_train and cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY > 0:
+        extra_args["no_random_pack_probability"] = cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY
+    if is_train:
+        extra_args["safeguard_positive_caption"] = cfg.DATASETS.SAFEGUARD_POSITIVE_CAPTION
+    if is_train:
+        extra_args["local_debug"] = cfg.DATASETS.LOCAL_DEBUG
+    if is_train:
+        extra_args["no_mask_for_od"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_OD
+    if is_train:
+        extra_args["no_mask_for_gold"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_GOLD
+    if is_train:
+        extra_args["mlm_obj_for_only_positive"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_OBJ_FOR_ONLY_POSITIVE
+    if cfg.DATASETS.OVERRIDE_CATEGORY and cfg.DATASETS.USE_OVERRIDE_CATEGORY:
+        extra_args["override_category"] = cfg.DATASETS.OVERRIDE_CATEGORY
+    if is_train:
+        extra_args["caption_format_version"] = cfg.DATASETS.CAPTION_FORMAT_VERSION
+    if is_train:
+        extra_args["special_safeguard_for_coco_grounding"] = cfg.DATASETS.SPECIAL_SAFEGUARD_FOR_COCO_GROUNDING
+    if is_train:
+        extra_args["diver_box_for_vqa"] = cfg.DATASETS.DIVER_BOX_FOR_VQA
+    extra_args["caption_prompt"] = cfg.DATASETS.CAPTION_PROMPT
+    extra_args["use_caption_prompt"] = cfg.DATASETS.USE_CAPTION_PROMPT
+
+    # extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
+    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
+        # extra_args['tokenizer'] = build_tokenizer("clip")
+        from transformers import CLIPTokenizerFast
+        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
+            extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True, mask_token='ðŁĴij</w>')
+        else:
+            extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True)
+    else:
+        extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
+
+    if isinstance(dataset_list[0], (tuple, list)):
+        datasets = build_dataset_by_group(dataset_list, transforms, DatasetCatalog, is_train,
+                                          class_by_group=cfg.DATASETS.ALTERNATIVE_TRAINING,
+                                          class_concat=cfg.DATASETS.CLASS_CONCAT,
+                                          extra_args=extra_args)
+    else:
+        datasets = build_dataset(cfg, dataset_list, transforms, DatasetCatalog, is_train,
+                                 class_concat=cfg.DATASETS.CLASS_CONCAT,
+                                 extra_args=extra_args)
+
+    data_loaders = []
+    for di, dataset in enumerate(datasets):
+        if is_train and cfg.SOLVER.MAX_EPOCH > 0:
+            num_iters = cfg.SOLVER.MAX_EPOCH * len(dataset) // cfg.SOLVER.IMS_PER_BATCH
+            print("Number of iterations are {}".format(num_iters))
+            cfg.defrost()
+            cfg.SOLVER.MAX_ITER = num_iters
+            cfg.SOLVER.DATASET_LENGTH = len(dataset)
+            cfg.freeze()
+        if is_train and cfg.SOLVER.MULTI_MAX_EPOCH:
+            num_iters = None
+            cfg.defrost()
+            cfg.SOLVER.MULTI_MAX_ITER += (cfg.SOLVER.MULTI_MAX_EPOCH[di] * len(dataset) // cfg.SOLVER.IMS_PER_BATCH,)
+            cfg.freeze()
+
+        if is_train and cfg.DATALOADER.DISTRIBUTE_CHUNK_AMONG_NODE:
+            from .datasets.custom_distributed_sampler import DistributedSamplerChunkByNode
+            chunk_or_not = []
+            for i in dataset_list:
+                if "bing_caption" in i:
+                    chunk_or_not.append(True)
+                else:
+                    chunk_or_not.append(False)
+            assert(len(chunk_or_not) == len(dataset.datasets))
+            '''
+            If we are training on 4 nodes, each with 8 GPUs
+            '''
+            num_nodes = int(os.getenv('NODE_COUNT', os.getenv('OMPI_COMM_WORLD_SIZE', 1)))
+            local_size = cfg.num_gpus//num_nodes
+            node_rank = int(os.getenv('NODE_RANK', os.getenv('OMPI_COMM_WORLD_RANK', 0)))
+            local_rank = cfg.local_rank
+            sampler = DistributedSamplerChunkByNode(
+                dataset = dataset,
+                all_datasets = dataset.datasets, # Assumming dataset is a ConcateDataset instance,
+                chunk_or_not = chunk_or_not,
+                num_replicas = cfg.num_gpus, # total GPU number, e.g., 32
+                rank = dist.get_rank(), # Global Rank, e.g., 0~31
+                node_rank = node_rank, # Node Rank, e.g., 0~3
+                node_number = num_nodes, # how many node e.g., 4
+                process_num_per_node = local_size, # e.g., 8
+                rank_within_local_node = local_rank, # e.g., 0~7
+            )
+        else:
+            sampler = make_data_sampler(dataset, shuffle, is_distributed, num_replicas=num_replicas, rank=rank,
+                                        use_random_seed=cfg.DATALOADER.USE_RANDOM_SEED)
+        batch_sampler = make_batch_data_sampler(
+            dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, drop_last=is_train
+        )
+        collator = BBoxAugCollator() if not is_train and cfg.TEST.USE_MULTISCALE else BatchCollator(
+            cfg.DATALOADER.SIZE_DIVISIBILITY)
+        num_workers = cfg.DATALOADER.NUM_WORKERS
+        data_loader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=num_workers,
+            batch_sampler=batch_sampler,
+            collate_fn=collator,
+        )
+        data_loaders.append(data_loader)
+    if is_train and cfg.SOLVER.MULTI_MAX_EPOCH:
+        cfg.defrost()
+        cfg.SOLVER.MULTI_MAX_ITER += (
+            cfg.SOLVER.MULTI_MAX_EPOCH[-1] * min([len(dataset) // cfg.SOLVER.IMS_PER_BATCH for dataset in datasets]),)
+        cfg.freeze()
+
+    if is_train and not cfg.DATASETS.ALTERNATIVE_TRAINING and not cfg.DATASETS.MULTISTAGE_TRAINING:
+        # during training, a single (possibly concatenated) data_loader is returned
+        assert len(data_loaders) == 1
+        return data_loaders[0]
+
+    return data_loaders
diff --git a/maskrcnn_benchmark/data/collate_batch.py b/maskrcnn_benchmark/data/collate_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf08fd9b5fd67ef41e659bd6df8ae20933359435
--- /dev/null
+++ b/maskrcnn_benchmark/data/collate_batch.py
@@ -0,0 +1,93 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from maskrcnn_benchmark.structures.image_list import to_image_list
+
+import pdb
+class BatchCollator(object):
+    """
+    From a list of samples from the dataset,
+    returns the batched images and targets.
+    This should be passed to the DataLoader
+    """
+
+    def __init__(self, size_divisible=0):
+        self.size_divisible = size_divisible
+
+    def __call__(self, batch):
+        transposed_batch = list(zip(*batch))
+        
+        images = to_image_list(transposed_batch[0], self.size_divisible)
+        targets = transposed_batch[1]
+        img_ids = transposed_batch[2]
+        positive_map = None
+        positive_map_eval = None
+        greenlight_map = None
+
+        if isinstance(targets[0], dict):
+            return images, targets, img_ids, positive_map, positive_map_eval
+
+        if "greenlight_map" in transposed_batch[1][0].fields():
+            greenlight_map = torch.stack([i.get_field("greenlight_map") for i in transposed_batch[1]], dim = 0)
+
+        if "positive_map" in transposed_batch[1][0].fields():
+            # we batch the positive maps here
+            # Since in general each batch element will have a different number of boxes,
+            # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes.
+            max_len = max([v.get_field("positive_map").shape[1] for v in transposed_batch[1]])
+            nb_boxes = sum([v.get_field("positive_map").shape[0] for v in transposed_batch[1]])
+            batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool)
+            cur_count = 0
+            for v in transposed_batch[1]:
+                cur_pos = v.get_field("positive_map")
+                batched_pos_map[cur_count: cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos
+                cur_count += len(cur_pos)
+
+            assert cur_count == len(batched_pos_map)
+            positive_map = batched_pos_map.float()
+        
+
+        if "positive_map_eval" in transposed_batch[1][0].fields():
+            # we batch the positive maps here
+            # Since in general each batch element will have a different number of boxes,
+            # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes.
+            max_len = max([v.get_field("positive_map_eval").shape[1] for v in transposed_batch[1]])
+            nb_boxes = sum([v.get_field("positive_map_eval").shape[0] for v in transposed_batch[1]])
+            batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool)
+            cur_count = 0
+            for v in transposed_batch[1]:
+                cur_pos = v.get_field("positive_map_eval")
+                batched_pos_map[cur_count: cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos
+                cur_count += len(cur_pos)
+
+            assert cur_count == len(batched_pos_map)
+            # assert batched_pos_map.sum().item() == sum([v["positive_map"].sum().item() for v in batch[1]])
+            positive_map_eval = batched_pos_map.float()
+
+
+        return images, targets, img_ids, positive_map, positive_map_eval, greenlight_map
+
+
+class BBoxAugCollator(object):
+    """
+    From a list of samples from the dataset,
+    returns the images and targets.
+    Images should be converted to batched images in `im_detect_bbox_aug`
+    """
+
+    def __call__(self, batch):
+        # return list(zip(*batch))
+        transposed_batch = list(zip(*batch))
+
+        images = transposed_batch[0]
+        targets = transposed_batch[1]
+        img_ids = transposed_batch[2]
+        positive_map = None
+        positive_map_eval = None
+
+        if isinstance(targets[0], dict):
+            return images, targets, img_ids, positive_map, positive_map_eval
+
+        return images, targets, img_ids, positive_map, positive_map_eval
+
+
+
diff --git a/maskrcnn_benchmark/data/datasets/__init__.py b/maskrcnn_benchmark/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1136cd63ccdf0cf6207226ab7fad98181e3aa0dc
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from .coco import COCODataset
+from .voc import PascalVOCDataset
+from .concat_dataset import ConcatDataset
+from .background import Background
+from .tsv import TSVDataset, ODTSVDataset
+
+from .modulated_coco import ModulatedDataset, CocoDetection, CocoGrounding
+from .flickr import FlickrDataset
+from .refexp import RefExpDataset
+from .mixed import MixedDataset
+from .gqa import GQADataset
+
+from .coco_dt import CocoDetectionTSV
+from .caption import CaptionTSV
+from .lvis import LvisDetection
+from .pseudo_data import PseudoData
+from .phrasecut import PhrasecutDetection
+
+__all__ = ["COCODataset", "TSVDataset", "ODTSVDataset", "ConcatDataset", "PascalVOCDataset", "Background",
+           "ModulatedDataset", "MixedDataset", "CocoDetection", "FlickrDataset", "RefExpDataset", "GQADataset",
+           "CocoDetectionTSV", "CocoGrounding", "CaptionTSV", "LvisDetection", "PseudoData", "PhrasecutDetection"
+           ]
diff --git a/maskrcnn_benchmark/data/datasets/background.py b/maskrcnn_benchmark/data/datasets/background.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f2051da45b046fc3481e6116d75769fbe42be0d
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/background.py
@@ -0,0 +1,53 @@
+import os
+import os.path
+import json
+from PIL import Image
+
+import torch
+import torchvision
+import torch.utils.data as data
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+class Background(data.Dataset):
+    """ Background
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+    """
+
+    def __init__(self, ann_file, root, remove_images_without_annotations=None, transforms=None):
+        self.root = root
+
+        with open(ann_file, 'r') as f:
+            self.ids = json.load(f)['images']
+        self.transform = transforms
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+        """
+        im_info = self.ids[index]
+        path = im_info['file_name']
+        fp = os.path.join(self.root, path)
+
+        img = Image.open(fp).convert('RGB')
+        if self.transform is not None:
+            img, _ = self.transform(img, None)
+        null_target = BoxList(torch.zeros((0,4)), (img.shape[-1], img.shape[-2]))
+        null_target.add_field('labels', torch.zeros(0))
+
+        return img, null_target, index
+
+    def __len__(self):
+        return len(self.ids)
+
+    def get_img_info(self, index):
+        im_info = self.ids[index]
+        return im_info
\ No newline at end of file
diff --git a/maskrcnn_benchmark/data/datasets/box_label_loader.py b/maskrcnn_benchmark/data/datasets/box_label_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d40c758aacd1915f97f163f4736ecd86311fcc0
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/box_label_loader.py
@@ -0,0 +1,251 @@
+import torch
+import numpy as np
+import math
+import base64
+import collections
+import pycocotools.mask as mask_utils
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
+
+
+class LabelLoader(object):
+    def __init__(self, labelmap, extra_fields=(), filter_duplicate_relations=False, ignore_attr=None, ignore_rel=None,
+                 mask_mode="poly"):
+        self.labelmap = labelmap
+        self.extra_fields = extra_fields
+        self.supported_fields = ["class", "conf", "attributes", 'scores_all', 'boxes_all', 'feature', "mask"]
+        self.filter_duplicate_relations = filter_duplicate_relations
+        self.ignore_attr = set(ignore_attr) if ignore_attr != None else set()
+        self.ignore_rel = set(ignore_rel) if ignore_rel != None else set()
+        assert mask_mode == "poly" or mask_mode == "mask"
+        self.mask_mode = mask_mode
+
+    def __call__(self, annotations, img_size, remove_empty=False, load_fields=None):
+        boxes = [obj["rect"] for obj in annotations]
+        boxes = torch.as_tensor(boxes).reshape(-1, 4)
+        target = BoxList(boxes, img_size, mode="xyxy")
+
+        if load_fields is None:
+            load_fields = self.extra_fields
+
+        for field in load_fields:
+            assert field in self.supported_fields, "Unsupported field {}".format(field)
+            if field == "class":
+                classes = self.add_classes(annotations)
+                target.add_field("labels", classes)
+            elif field == "conf":
+                confidences = self.add_confidences(annotations)
+                target.add_field("scores", confidences)
+            elif field == "attributes":
+                attributes = self.add_attributes(annotations)
+                target.add_field("attributes", attributes)
+            elif field == "scores_all":
+                scores_all = self.add_scores_all(annotations)
+                target.add_field("scores_all", scores_all)
+            elif field == "boxes_all":
+                boxes_all = self.add_boxes_all(annotations)
+                target.add_field("boxes_all", boxes_all)
+            elif field == "feature":
+                features = self.add_features(annotations)
+                target.add_field("box_features", features)
+            elif field == "mask":
+                masks, is_box_mask = self.add_masks(annotations, img_size)
+                target.add_field("masks", masks)
+                target.add_field("is_box_mask", is_box_mask)
+
+        target = target.clip_to_image(remove_empty=remove_empty)
+        return target
+
+    def get_box_mask(self, rect, img_size):
+        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
+        if self.mask_mode == "poly":
+            return [[x1, y1, x1, y2, x2, y2, x2, y1]]
+        elif self.mask_mode == "mask":
+            # note the order of height/width order in mask is opposite to image
+            mask = np.zeros([img_size[1], img_size[0]], dtype=np.uint8)
+            mask[math.floor(y1):math.ceil(y2), math.floor(x1):math.ceil(x2)] = 255
+            encoded_mask = mask_utils.encode(np.asfortranarray(mask))
+            encoded_mask["counts"] = encoded_mask["counts"].decode("utf-8")
+            return encoded_mask
+
+    def add_masks(self, annotations, img_size):
+        masks = []
+        is_box_mask = []
+        for obj in annotations:
+            if "mask" in obj:
+                masks.append(obj["mask"])
+                is_box_mask.append(0)
+            else:
+                masks.append(self.get_box_mask(obj["rect"], img_size))
+                is_box_mask.append(1)
+        masks = SegmentationMask(masks, img_size, mode=self.mask_mode)
+        is_box_mask = torch.tensor(is_box_mask)
+        return masks, is_box_mask
+
+    def add_classes(self, annotations):
+        class_names = [obj["class"] for obj in annotations]
+        classes = [None] * len(class_names)
+        for i in range(len(class_names)):
+            classes[i] = self.labelmap['class_to_ind'][class_names[i]]
+        return torch.tensor(classes)
+
+    def add_confidences(self, annotations):
+        confidences = []
+        for obj in annotations:
+            if "conf" in obj:
+                confidences.append(obj["conf"])
+            else:
+                confidences.append(1.0)
+        return torch.tensor(confidences)
+
+    def add_attributes(self, annotations):
+        # the maximal number of attributes per object is 16
+        attributes = [[0] * 16 for _ in range(len(annotations))]
+        for i, obj in enumerate(annotations):
+            for j, attr in enumerate(obj["attributes"]):
+                attributes[i][j] = self.labelmap['attribute_to_ind'][attr]
+        return torch.tensor(attributes)
+
+    def add_features(self, annotations):
+        features = []
+        for obj in annotations:
+            features.append(np.frombuffer(base64.b64decode(obj['feature']), np.float32))
+        return torch.tensor(features)
+
+    def add_scores_all(self, annotations):
+        scores_all = []
+        for obj in annotations:
+            scores_all.append(np.frombuffer(base64.b64decode(obj['scores_all']), np.float32))
+        return torch.tensor(scores_all)
+
+    def add_boxes_all(self, annotations):
+        boxes_all = []
+        for obj in annotations:
+            boxes_all.append(np.frombuffer(base64.b64decode(obj['boxes_all']), np.float32).reshape(-1, 4))
+        return torch.tensor(boxes_all)
+
+    def relation_loader(self, relation_annos, target):
+        if self.filter_duplicate_relations:
+            # Filter out dupes!
+            all_rel_sets = collections.defaultdict(list)
+            for triplet in relation_annos:
+                all_rel_sets[(triplet['subj_id'], triplet['obj_id'])].append(triplet)
+            relation_annos = [np.random.choice(v) for v in all_rel_sets.values()]
+
+        # get M*M pred_labels
+        relation_triplets = []
+        relations = torch.zeros([len(target), len(target)], dtype=torch.int64)
+        for i in range(len(relation_annos)):
+            if len(self.ignore_rel) != 0 and relation_annos[i]['class'] in self.ignore_rel:
+                continue
+            subj_id = relation_annos[i]['subj_id']
+            obj_id = relation_annos[i]['obj_id']
+            predicate = self.labelmap['relation_to_ind'][relation_annos[i]['class']]
+            relations[subj_id, obj_id] = predicate
+            relation_triplets.append([subj_id, obj_id, predicate])
+
+        relation_triplets = torch.tensor(relation_triplets)
+        target.add_field("relation_labels", relation_triplets)
+        target.add_field("pred_labels", relations)
+        return target
+
+
+class BoxLabelLoader(object):
+    def __init__(self, labelmap, extra_fields=(), ignore_attrs=(),
+                 mask_mode="poly"):
+        self.labelmap = labelmap
+        self.extra_fields = extra_fields
+        self.ignore_attrs = ignore_attrs
+        assert mask_mode == "poly" or mask_mode == "mask"
+        self.mask_mode = mask_mode
+        self.all_fields = ["class", "mask", "confidence",
+                           "attributes_encode", "IsGroupOf", "IsProposal"]
+
+    def __call__(self, annotations, img_size, remove_empty=True):
+        boxes = [obj["rect"] for obj in annotations]
+        boxes = torch.as_tensor(boxes).reshape(-1, 4)
+        target = BoxList(boxes, img_size, mode="xyxy")
+
+        for field in self.extra_fields:
+            assert field in self.all_fields, "Unsupported field {}".format(field)
+            if field == "class":
+                classes = self.add_classes_with_ignore(annotations)
+                target.add_field("labels", classes)
+            elif field == "mask":
+                masks, is_box_mask = self.add_masks(annotations, img_size)
+                target.add_field("masks", masks)
+                target.add_field("is_box_mask", is_box_mask)
+            elif field == "confidence":
+                confidences = self.add_confidences(annotations)
+                target.add_field("confidences", confidences)
+            elif field == "attributes_encode":
+                attributes = self.add_attributes(annotations)
+                target.add_field("attributes", attributes)
+            elif field == "IsGroupOf":
+                is_group = [1 if 'IsGroupOf' in obj and obj['IsGroupOf'] == 1 else 0
+                            for obj in annotations]
+                target.add_field("IsGroupOf", torch.tensor(is_group))
+            elif field == "IsProposal":
+                is_proposal = [1 if "IsProposal" in obj and obj['IsProposal'] == 1 else 0
+                               for obj in annotations]
+                target.add_field("IsProposal", torch.tensor(is_proposal))
+
+        target = target.clip_to_image(remove_empty=remove_empty)
+        return target
+
+    def add_classes_with_ignore(self, annotations):
+        class_names = [obj["class"] for obj in annotations]
+        classes = [None] * len(class_names)
+        if self.ignore_attrs:
+            for i, obj in enumerate(annotations):
+                if any([obj[attr] for attr in self.ignore_attrs if attr in obj]):
+                    classes[i] = -1
+        for i, cls in enumerate(classes):
+            if cls != -1:
+                classes[i] = self.labelmap[class_names[i]] + 1  # 0 is saved for background
+        return torch.tensor(classes)
+
+    def add_masks(self, annotations, img_size):
+        masks = []
+        is_box_mask = []
+        for obj in annotations:
+            if "mask" in obj:
+                masks.append(obj["mask"])
+                is_box_mask.append(0)
+            else:
+                masks.append(self.get_box_mask(obj["rect"], img_size))
+                is_box_mask.append(1)
+        masks = SegmentationMask(masks, img_size, mode=self.mask_mode)
+        is_box_mask = torch.tensor(is_box_mask)
+        return masks, is_box_mask
+
+    def get_box_mask(self, rect, img_size):
+        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
+        if self.mask_mode == "poly":
+            return [[x1, y1, x1, y2, x2, y2, x2, y1]]
+        elif self.mask_mode == "mask":
+            # note the order of height/width order in mask is opposite to image
+            mask = np.zeros([img_size[1], img_size[0]], dtype=np.uint8)
+            mask[math.floor(y1):math.ceil(y2), math.floor(x1):math.ceil(x2)] = 255
+            encoded_mask = mask_utils.encode(np.asfortranarray(mask))
+            encoded_mask["counts"] = encoded_mask["counts"].decode("utf-8")
+            return encoded_mask
+
+    def add_confidences(self, annotations):
+        confidences = []
+        for obj in annotations:
+            if "confidence" in obj:
+                confidences.append(obj["confidence"])
+            elif "conf" in obj:
+                confidences.append(obj["conf"])
+            else:
+                confidences.append(1.0)
+        return torch.tensor(confidences)
+
+    def add_attributes(self, annotations):
+        # we know that the maximal number of attributes per object is 16
+        attributes = [[0] * 16 for _ in range(len(annotations))]
+        for i, obj in enumerate(annotations):
+            attributes[i][:len(obj["attributes_encode"])] = obj["attributes_encode"]
+        return torch.tensor(attributes)
diff --git a/maskrcnn_benchmark/data/datasets/caption.py b/maskrcnn_benchmark/data/datasets/caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4a5ec88ab02e589a0333a5d65f907b545d710ed
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/caption.py
@@ -0,0 +1,279 @@
+import torch
+import torch.distributed as dist
+import time
+from torchvision.ops import nms
+import random
+import numpy as np
+from PIL import Image, ImageDraw
+import pdb
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from .modulated_coco import ConvertCocoPolysToMask
+from .tsv import ODTSVDataset, TSVYamlDataset
+from .od_to_grounding import sanity_check_target_after_processing
+
+class CaptionTSV(TSVYamlDataset):
+    def __init__(self,
+                 yaml_file,
+                 transforms,
+                 return_tokens,
+                 return_masks,
+                 tokenizer,
+                 caption_min_box=1,
+                 replace_clean_label=False,
+                 further_screen=False,
+                 caption_conf=0.5,
+                 caption_nms=-1,
+                 pack_random_caption_number=0,
+                 inference_caption=False,
+                 sample_negative_for_grounding_data=-1,
+                 random_pack_prob=-1.0,
+                 no_random_pack_probability=0.0,
+                 safeguard_positive_caption=True,
+                 mlm_obj_for_only_positive=False,
+                 caption_format_version="v1",
+                 local_debug=False,
+                 max_query_len=256,
+                 **kwargs
+                 ):
+        super(CaptionTSV, self).__init__(yaml_file, None, replace_clean_label)
+        self.yaml_file = yaml_file
+        self._transforms = transforms
+        self.max_query_len = max_query_len
+        self.prepare = ConvertCocoPolysToMask(return_masks=return_masks,
+                                              return_tokens=return_tokens,
+                                              tokenizer=tokenizer,
+                                              max_query_len=max_query_len)
+        self.tokenizer = tokenizer
+        self.caption_min_box = caption_min_box
+        self.replace_clean_label = replace_clean_label
+        self.further_screen = further_screen
+        self.pack_random_caption_number = pack_random_caption_number
+        self.caption_format_version = caption_format_version
+
+        self.caption_conf = caption_conf
+        self.caption_nms = caption_nms
+        self.inference_caption = inference_caption
+        self.sample_negative_for_grounding_data = sample_negative_for_grounding_data
+        self.random_pack_prob = random_pack_prob
+        self.no_random_pack_probability = no_random_pack_probability
+        self.safeguard_positive_caption = safeguard_positive_caption
+        self.mlm_obj_for_only_positive = mlm_obj_for_only_positive
+        try:
+            self.rank = dist.get_rank()
+        except:
+            self.rank = 0
+
+    def __len__(self):
+        return super(CaptionTSV, self).__len__()
+
+    def pack_caption(self, positive_caption, negative_captions, original_tokens_positive):
+        if len(negative_captions) == 0:
+            return positive_caption, original_tokens_positive, [(0, len(positive_caption))]
+        if self.safeguard_positive_caption:
+            length_of_each_caption = []
+            for caption in negative_captions + [positive_caption]:
+                tokenized = self.tokenizer(caption, return_tensors="pt")
+                length_of_each_caption.append(tokenized.input_ids.size(-1))
+            max_length = self.max_query_len - length_of_each_caption[-1]
+            indexes = list(range(len(negative_captions)))
+            random.shuffle(indexes)
+            new_caption_list = [positive_caption]
+            for i in indexes:
+                if length_of_each_caption[i] < max_length:
+                    new_caption_list.append(negative_captions[i])
+                    max_length -= length_of_each_caption[i]
+        else:
+            new_caption_list = [positive_caption] + negative_captions
+        random.shuffle(new_caption_list)
+
+        new_caption = ''
+
+        for i in new_caption_list:
+            if i == positive_caption:
+                start_position = len(new_caption)
+            new_caption += i
+            if not i.endswith("."):
+                new_caption += "."
+            new_caption += " "
+
+        # shift the token positions the boxes are aligned to
+        for index, i in enumerate(original_tokens_positive):
+            original_tokens_positive[index] = [tuple(j) for j in i]
+        for i in original_tokens_positive:
+            for index, j in enumerate(i):
+                i[index] = (j[0] + start_position, j[1] + start_position)
+
+        return new_caption, original_tokens_positive, [(start_position, start_position + len(positive_caption))]
+
+    def __get_negative_captions__(self, idx, negative_size=7):
+        negative_captions = []
+        for i in range(negative_size):
+            img, anno, _, scale = super(CaptionTSV, self).__getitem__(np.random.choice(len(self)))
+            caption = anno["caption"]
+            negative_captions.append(caption)
+
+        return negative_captions
+
+    def __getitem__(self, idx):
+        try:
+            img, anno, _, scale = super(CaptionTSV, self).__getitem__(idx)
+            if self.inference_caption:
+                caption = None
+                if isinstance(anno, list):
+                    caption = anno[0]["caption"]  # inference mode for bing
+                    anno = []
+                elif len(anno) == 1:
+                    caption = anno["caption"]  # inference mode for googlecc
+                    anno = []
+                else:
+                    caption = " ".join(anno["captions"])
+                    anno = []
+            else:
+                '''
+                An example
+                {'img_h': 1154, 'img_w': 1600, 'caption': 'xxx', 'tokens_positive': [[[47, 50], [51, 53], [54, 59]], [[32, 35], [36, 41]], [[32, 35], [36, 41]], [[0, 3], [3, 6], [6, 10], [11, 16], [17, 19], [20, 23]], [[32, 35], [36, 41]], [[32, 35], [36, 41]]], 'bboxes': [[7.344961166381836, 10.479412078857422, 1592.2679443359375, 1090.0028076171875], [950.32861328125, 346.572021484375, 1333.2373046875, 679.3215942382812], [927.44140625, 342.7712707519531, 1389.833984375, 719.5758666992188], [90.48786163330078, 363.67572021484375, 1381.8631591796875, 1078.687744140625], [122.84217071533203, 422.6786193847656, 507.845703125, 667.2651977539062], [80.62384033203125, 416.500244140625, 563.1666259765625, 734.603271484375]], 'scores': [0.7966700196266174, 0.8952182531356812, 0.8186006546020508, 0.9995516538619995, 0.8021856546401978, 0.8923134803771973]}
+                '''
+                if len(anno["bboxes"]) < self.caption_min_box:  # Retry triggered!
+                    return self[np.random.choice(len(self))]
+
+                if self.caption_format_version == "v2":
+                    anno = self.convert_anno_from_v2_to_v1(anno)
+
+                try:
+                    if self.further_screen:
+                        conf = self.caption_conf
+                        nms_thre = self.caption_nms
+
+                        bboxes = torch.as_tensor(anno["bboxes"]).float()
+                        scores = torch.as_tensor(anno["scores"])
+                        tokens_positive = anno["tokens_positive"]
+
+                        # print("\n\n\n\n tokens_positive in original data", tokens_positive)
+
+                        keep = scores > conf
+                        scores = scores[keep]
+                        bboxes = bboxes[keep]
+                        tokens_positive = [i for index, i in enumerate(tokens_positive) if keep[index]]
+
+                        assert (len(tokens_positive) == len(bboxes) == len(scores))
+
+                        if len(bboxes) < self.caption_min_box:  # Retry triggered!
+                            return self[np.random.choice(len(self))]
+
+                        if nms_thre > 0:
+                            keep = nms(boxes=bboxes, scores=scores, iou_threshold=nms_thre)
+                            scores = scores[keep]
+                            bboxes = bboxes[keep]
+                            tokens_positive = [tokens_positive[i] for i in keep]
+                            assert (len(tokens_positive) == len(bboxes) == len(scores))
+
+                        # Write back
+                        anno["bboxes"] = bboxes.tolist()
+                        anno["scores"] = scores.tolist()
+                        anno["tokens_positive"] = tokens_positive
+
+                    boxes = torch.as_tensor(anno["bboxes"])
+
+                    if len(boxes) < self.caption_min_box:  # Retry triggered!
+                        return self[np.random.choice(len(self))]
+
+                    target = BoxList(boxes, (anno["img_w"], anno["img_h"]), mode="xyxy")
+                    target = target.clip_to_image(remove_empty=True)
+
+                    caption = anno["caption"]
+                    # print("original caption", caption)
+                    empty_everything = False
+                    if self.sample_negative_for_grounding_data != -1:
+                        if random.random() < self.sample_negative_for_grounding_data:
+                            empty_everything = True
+
+                    if empty_everything:
+                        caption = self.__get_negative_captions__(idx, negative_size=1)[0]
+
+                    if self.pack_random_caption_number != 0:
+                        if self.random_pack_prob != -1.0:
+                            if random.random() < self.no_random_pack_probability:
+                                negative_pack_number = 0
+                            elif random.random() < self.random_pack_prob:
+                                negative_pack_number = self.pack_random_caption_number
+                            else:
+                                negative_pack_number = np.random.choice(self.pack_random_caption_number)
+                        else:
+                            negative_pack_number = self.pack_random_caption_number
+
+                        negative_captions = self.__get_negative_captions__(idx, negative_size=negative_pack_number)
+
+                        caption, anno["tokens_positive"], greenlight_span_for_masked_lm_objective = self.pack_caption(
+                            caption, negative_captions, anno["tokens_positive"])
+                    else:
+                        greenlight_span_for_masked_lm_objective = [(0, len(caption))]
+
+                    if not self.mlm_obj_for_only_positive:
+                        greenlight_span_for_masked_lm_objective = [(0, len(caption))]
+                    
+                    new_anno = []
+                    areas = target.area()
+                    for i in range(len(target)):
+                        new_anno_i = {}
+                        new_anno_i["area"] = areas[i]
+                        new_anno_i["iscrowd"] = 0
+                        new_anno_i["image_id"] = idx
+                        new_anno_i["category_id"] = 1  # following vg and others
+                        new_anno_i["id"] = None
+                        new_anno_i['bbox'] = target.bbox[i].numpy().tolist()
+                        new_anno_i["tokens_positive"] = anno["tokens_positive"][i]
+                        new_anno.append(new_anno_i)
+
+                except:
+                    return self[np.random.choice(len(self))]
+
+                anno = new_anno
+                if empty_everything:
+                    anno = []
+
+            annotations = {"image_id": idx, "annotations": anno, "caption": caption}
+            annotations["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
+            img, annotations = self.prepare(img, annotations, box_format="xyxy")
+
+            if self._transforms is not None:
+                img, target = self._transforms(img, target)
+
+            # add additional property
+            for ann in annotations:
+                target.add_field(ann, annotations[ann])
+        except:
+            print("Outter Retry triggered!!")
+            return self[np.random.choice(len(self))]
+
+        sanity_check_target_after_processing(target)
+        
+        return img, target, idx
+
+    def convert_anno_from_v2_to_v1(self, anno):
+        flatterned_bboxes = []
+        flatterned_tokens_positive = []
+        flatterned_bboxes_scores = []
+        for i in range(len(anno["bboxes"])):
+            # i is the index for entity
+            for j in range(len(anno["bboxes"][i])):
+                # j is the index for each box
+                flatterned_bboxes.append(anno["bboxes"][i][j])
+                flatterned_tokens_positive.append(
+                    anno["tokens_positive"][i])  # Assume this box corresponds to all the token_spans for this entity
+                flatterned_bboxes_scores.append(anno["scores"][i][j])
+        anno["bboxes"] = flatterned_bboxes
+        anno["tokens_positive"] = flatterned_tokens_positive
+        anno["scores"] = flatterned_bboxes_scores
+        return anno
+
+
+    def get_raw_image(self, idx):
+        image, *_ = super(CaptionTSV, self).__getitem__(idx)
+        return image
+
+    def get_img_id(self, idx):
+        line_no = self.get_line_no(idx)
+        if self.label_tsv is not None:
+            row = self.label_tsv.seek(line_no)
+            img_id = row[0]
+            return img_id
diff --git a/maskrcnn_benchmark/data/datasets/coco.py b/maskrcnn_benchmark/data/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..095af9ea67f08acb93fe4d6b175708cca60809ea
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/coco.py
@@ -0,0 +1,268 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import os
+import os.path
+import math
+from PIL import Image, ImageDraw
+
+import random
+import numpy as np
+
+import torch
+import torchvision
+import torch.utils.data as data
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
+from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
+from maskrcnn_benchmark.config import cfg
+import pdb
+
+def _count_visible_keypoints(anno):
+    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
+
+
+def _has_only_empty_bbox(anno):
+    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+
+def has_valid_annotation(anno):
+    # if it's empty, there is no annotation
+    if len(anno) == 0:
+        return False
+    # if all boxes have close to zero area, there is no annotation
+    if _has_only_empty_bbox(anno):
+        return False
+    # keypoints task have a slight different critera for considering
+    # if an annotation is valid
+    if "keypoints" not in anno[0]:
+        return True
+    # for keypoint detection tasks, only consider valid images those
+    # containing at least min_keypoints_per_image
+    if _count_visible_keypoints(anno) >= cfg.DATALOADER.MIN_KPS_PER_IMS:
+        return True
+    return False
+
+
+def pil_loader(path, retry=5):
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    ri = 0
+    while ri < retry:
+        try:
+            with open(path, 'rb') as f:
+                img = Image.open(f)
+                return img.convert('RGB')
+        except:
+            ri += 1
+
+
+def rgb2id(color):
+    if isinstance(color, np.ndarray) and len(color.shape) == 3:
+        if color.dtype == np.uint8:
+            color = color.astype(np.int32)
+        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
+    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
+
+
+class CocoDetection(data.Dataset):
+    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(self, root, annFile, transform=None, target_transform=None):
+        from pycocotools.coco import COCO
+        self.root = root
+        self.coco = COCO(annFile)
+        self.ids = list(self.coco.imgs.keys())
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __getitem__(self, index, return_meta=False):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+        """
+        coco = self.coco
+        img_id = self.ids[index]
+        if isinstance(img_id, str):
+            img_id = [img_id]
+        ann_ids = coco.getAnnIds(imgIds=img_id)
+        target = coco.loadAnns(ann_ids)
+
+        meta = coco.loadImgs(img_id)[0]
+        path = meta['file_name']
+        img = pil_loader(os.path.join(self.root, path))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        if return_meta:
+            return img, target, meta
+        else:
+            return img, target
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __repr__(self):
+        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
+        fmt_str += '    Root Location: {}\n'.format(self.root)
+        tmp = '    Transforms (if any): '
+        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        tmp = '    Target Transforms (if any): '
+        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        return fmt_str
+
+
+class COCODataset(CocoDetection):
+    def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None, ignore_crowd=True,
+                 max_box=-1,
+                 few_shot=0, one_hot=False, override_category=None, **kwargs
+                 ):
+        super(COCODataset, self).__init__(root, ann_file)
+        # sort indices for reproducible results
+        self.ids = sorted(self.ids)
+
+        # filter images without detection annotations
+        if remove_images_without_annotations:
+            ids = []
+            for img_id in self.ids:
+                if isinstance(img_id, str):
+                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
+                else:
+                    ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+                anno = self.coco.loadAnns(ann_ids)
+                if has_valid_annotation(anno):
+                    ids.append(img_id)
+            self.ids = ids
+
+        if few_shot:
+            ids = []
+            cats_freq = [few_shot]*len(self.coco.cats.keys())
+            if 'shuffle_seed' in kwargs and kwargs['shuffle_seed'] != 0:
+                import random
+                random.Random(kwargs['shuffle_seed']).shuffle(self.ids)
+                print("Shuffle the dataset with random seed: ", kwargs['shuffle_seed'])
+            for img_id in self.ids:
+                if isinstance(img_id, str):
+                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
+                else:
+                    ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+                anno = self.coco.loadAnns(ann_ids)
+                cat = set([ann['category_id'] for ann in anno]) #set/tuple corresponde to instance/image level
+                is_needed = sum([cats_freq[c-1]>0 for c in cat])
+                if is_needed:
+                    ids.append(img_id)
+                    for c in cat:
+                        cats_freq[c-1] -= 1
+                    # print(cat, cats_freq)
+            self.ids = ids
+        
+        if override_category is not None:
+            self.coco.dataset["categories"] = override_category
+            print("Override category: ", override_category)
+
+        self.json_category_id_to_contiguous_id = {
+            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
+        }
+        self.contiguous_category_id_to_json_id = {
+            v: k for k, v in self.json_category_id_to_contiguous_id.items()
+        }
+        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
+        self.transforms = transforms
+        self.ignore_crowd = ignore_crowd
+        self.max_box = max_box
+        self.one_hot = one_hot
+
+    def categories(self, no_background=True):
+        categories = self.coco.dataset["categories"]
+        label_list = {}
+        for index, i in enumerate(categories):
+            if not no_background or (i["name"] != "__background__" and i['id'] != 0):
+                label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"]
+        return label_list
+
+    def __getitem__(self, idx):
+
+        
+        img, anno = super(COCODataset, self).__getitem__(idx)
+
+        # filter crowd annotations
+        if self.ignore_crowd:
+            anno = [obj for obj in anno if obj["iscrowd"] == 0]
+
+        boxes = [obj["bbox"] for obj in anno]
+        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
+        if self.max_box > 0 and len(boxes) > self.max_box:
+            rand_idx = torch.randperm(self.max_box)
+            boxes = boxes[rand_idx, :]
+        else:
+            rand_idx = None
+        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
+
+        classes = [obj["category_id"] for obj in anno]
+        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
+        classes = torch.tensor(classes)
+
+        if rand_idx is not None:
+            classes = classes[rand_idx]
+        if cfg.DATASETS.CLASS_AGNOSTIC:
+            classes = torch.ones_like(classes)
+        target.add_field("labels", classes)
+
+        if anno and "segmentation" in anno[0]:
+            masks = [obj["segmentation"] for obj in anno]
+            masks = SegmentationMask(masks, img.size, mode='poly')
+            target.add_field("masks", masks)
+
+        if anno and "cbox" in anno[0]:
+            cboxes = [obj["cbox"] for obj in anno]
+            cboxes = torch.as_tensor(cboxes).reshape(-1, 4)  # guard against no boxes
+            cboxes = BoxList(cboxes, img.size, mode="xywh").convert("xyxy")
+            target.add_field("cbox", cboxes)
+
+        if anno and "keypoints" in anno[0]:
+            keypoints = []
+            gt_keypoint = self.coco.cats[1]['keypoints']  # <TODO> a better way to get keypoint description
+            use_keypoint = cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME
+            for obj in anno:
+                if len(use_keypoint) > 0:
+                    kps = []
+                    for name in use_keypoint:
+                        kp_idx = slice(3 * gt_keypoint.index(name), 3 * gt_keypoint.index(name) + 3)
+                        kps += obj["keypoints"][kp_idx]
+                    keypoints.append(kps)
+                else:
+                    keypoints.append(obj["keypoints"])
+            keypoints = PersonKeypoints(keypoints, img.size)
+            target.add_field("keypoints", keypoints)
+
+        target = target.clip_to_image(remove_empty=True)
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        if cfg.DATASETS.SAMPLE_RATIO != 0.0:
+            ratio = cfg.DATASETS.SAMPLE_RATIO
+            num_sample_target = math.ceil(len(target) * ratio) if ratio > 0 else math.ceil(-ratio)
+            sample_idx = torch.randperm(len(target))[:num_sample_target]
+            target = target[sample_idx]
+        return img, target, idx
+
+    def get_img_info(self, index):
+        img_id = self.id_to_img_map[index]
+        img_data = self.coco.imgs[img_id]
+        return img_data
diff --git a/maskrcnn_benchmark/data/datasets/coco_dt.py b/maskrcnn_benchmark/data/datasets/coco_dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..b050b3ed0dd2fa5b4974ef17fc8bdb26ed08fee7
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/coco_dt.py
@@ -0,0 +1,154 @@
+"""
+COCO dataset which returns image_id for evaluation.
+
+Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
+"""
+
+import torch
+import json
+from PIL import Image, ImageDraw
+
+from .modulated_coco import ConvertCocoPolysToMask
+from .tsv import ODTSVDataset
+from pycocotools.coco import COCO
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+import random
+from .od_to_grounding import convert_object_detection_to_grounding_optimized_for_od, check_for_positive_overflow, sanity_check_target_after_processing
+
+
+class CocoDetectionTSV(ODTSVDataset):
+    def __init__(self,
+                 name,
+                 yaml_file,
+                 transforms,
+                 return_tokens,
+                 tokenizer,
+                 extra_fields,
+                 random_sample_negative=-1,
+                 add_detection_prompt=False,
+                 add_detection_prompt_advanced=False,
+                 use_od_data_aug=False,
+                 control_probabilities={},
+                 disable_shuffle=False,
+                 prompt_engineer_version="v2",
+                 prompt_limit_negative=-1,
+                 positive_question_probability=0.6,
+                 negative_question_probability=0.8,
+                 full_question_probability=0.5,
+                 disable_clip_to_image=False,
+                 separation_tokens=" ",
+                 no_mask_for_od=False,
+                 max_num_labels=-1,
+                 max_query_len=256,
+                 **kwargs
+                 ):
+        super(CocoDetectionTSV, self).__init__(yaml_file, extra_fields, **kwargs)
+
+        self._transforms = transforms
+        self.name = name
+        self.max_query_len = max_query_len
+        self.prepare = ConvertCocoPolysToMask(
+            return_masks=False,
+            return_tokens=return_tokens,
+            tokenizer=tokenizer,
+            max_query_len=max_query_len
+        )
+        self.tokenizer = tokenizer
+
+        self.control_probabilities = control_probabilities
+        self.random_sample_negative = random_sample_negative
+        self.add_detection_prompt = add_detection_prompt
+        self.add_detection_prompt_advanced = add_detection_prompt_advanced
+        self.use_od_data_aug = use_od_data_aug
+
+        self.prompt_engineer_version = prompt_engineer_version
+        self.prompt_limit_negative = prompt_limit_negative
+        self.positive_question_probability = positive_question_probability
+        self.negative_question_probability = negative_question_probability
+        self.full_question_probability = full_question_probability
+        self.separation_tokens = separation_tokens
+        self.disable_clip_to_image = disable_clip_to_image
+        self.disable_shuffle = disable_shuffle
+        self.no_mask_for_od = no_mask_for_od
+        self.max_num_labels = max_num_labels
+
+    def __len__(self):
+        return super(CocoDetectionTSV, self).__len__()
+
+    def categories(self, no_background=True):
+        categories = self.coco.dataset["categories"]
+        label_list = {}
+        for index, i in enumerate(categories):
+            # assert(index + 1 == i["id"])
+            if not no_background or (i["name"] != "__background__" and i['id'] != 0):
+                label_list[i["id"]] = i["name"]
+        return label_list
+
+    def __getitem__(self, idx):
+        # tgt is a BoxList
+        img, target, _, scale = super(CocoDetectionTSV, self).__getitem__(idx)
+        image_id = self.get_img_id(idx)
+        restricted_negative_list = None
+
+        if not self.disable_clip_to_image:
+            target = target.clip_to_image(remove_empty=True)
+
+        original_box_num = len(target)
+
+        target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2) # leave some space for the special tokens
+
+        if len(target) < original_box_num:
+            print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))
+
+        annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
+            target=target,
+            image_id=image_id,
+            ind_to_class=self.ind_to_class,
+            disable_shuffle=self.disable_shuffle,
+            add_detection_prompt=self.add_detection_prompt,
+            add_detection_prompt_advanced=self.add_detection_prompt_advanced,
+            random_sample_negative=self.random_sample_negative,
+            control_probabilities=self.control_probabilities,
+            restricted_negative_list=restricted_negative_list,
+            separation_tokens=self.separation_tokens,
+            max_num_labels=self.max_num_labels,
+            positive_caption_length=positive_caption_length,
+            tokenizer=self.tokenizer,
+            max_seq_length=self.max_query_len-2
+        )
+
+        # assert(len(self.tokenizer.tokenize(caption)) <= self.max_query_len-2)
+
+        # print(caption)
+        anno = {"image_id": image_id, "annotations": annotations, "caption": caption, "label_to_positions": label_to_positions}
+        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
+
+        if self.no_mask_for_od:
+            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
+
+        img, anno = self.prepare(img, anno, box_format="xyxy")
+
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        
+        # add additional property
+        for ann in anno:
+            target.add_field(ann, anno[ann])
+
+        sanity_check_target_after_processing(target)
+
+        return img, target, idx
+
+    def get_raw_image(self, idx):
+        image, *_ = super(CocoDetectionTSV, self).__getitem__(idx)
+        return image
+
+    def get_img_id(self, idx):
+        line_no = self.get_line_no(idx)
+        if self.label_tsv is not None:
+            row = self.label_tsv.seek(line_no)
+            img_id = row[0]
+            try:
+                return int(img_id)
+            except:
+                return idx
diff --git a/maskrcnn_benchmark/data/datasets/concat_dataset.py b/maskrcnn_benchmark/data/datasets/concat_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cb6c0d96f906056c0b6d0d001db00c6eac2a5ae
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/concat_dataset.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import bisect
+
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+
+class ConcatDataset(_ConcatDataset):
+    """
+    Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra
+    method for querying the sizes of the image
+    """
+
+    def get_idxs(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return dataset_idx, sample_idx
+
+    def get_img_info(self, idx):
+        dataset_idx, sample_idx = self.get_idxs(idx)
+        return self.datasets[dataset_idx].get_img_info(sample_idx)
diff --git a/maskrcnn_benchmark/data/datasets/custom_distributed_sampler.py b/maskrcnn_benchmark/data/datasets/custom_distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf8d8c4ea2b5f603d4a3a94cb114c154b56566d
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/custom_distributed_sampler.py
@@ -0,0 +1,185 @@
+import math
+from typing import TypeVar, Optional, Iterator
+
+import torch
+from torch.utils.data import Sampler, Dataset
+import torch.distributed as dist
+import random
+import numpy as np
+import torch
+
+
+class DistributedSamplerChunkByNode(torch.utils.data.Sampler):
+
+    def __init__(self,
+                 dataset,
+                 all_datasets,
+                 chunk_or_not,
+                 num_replicas: Optional[int] = None,
+                 rank: Optional[int] = None,
+                 shuffle: bool = True,
+                 seed: int = 0,
+                 drop_last: bool = False,
+                 node_rank=0,
+                 node_number=1, process_num_per_node=1,
+                 rank_within_local_node=0) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if rank >= num_replicas or rank < 0:
+            raise ValueError(
+                "Invalid rank {}, rank should be in the interval"
+                " [0, {}]".format(rank, num_replicas - 1))
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.node_number = node_number
+        self.node_rank = node_rank
+        self.chunk_or_not = chunk_or_not
+        self.process_num_per_node = process_num_per_node
+        self.rank_within_local_node = rank_within_local_node
+
+        assert (self.process_num_per_node * self.node_number == self.num_replicas)
+
+        # 1. divide the datasets into two parts
+        normal_datasets = []
+        chunked_datasets = []
+        for dataset_i, chunk_i in zip(all_datasets, chunk_or_not):
+            if chunk_i:
+                chunked_datasets.append(dataset_i)
+            else:
+                normal_datasets.append(dataset_i)
+
+        # 2. calculate dataset sizes:
+        self.normal_dataset_size = sum(
+            [len(i) for i in normal_datasets])  # this part we follow the conventional distributed sampler
+
+        # 3. Divide 
+        self.current_node_start_range = -1
+        self.current_node_end_range = -1
+        assert (len(chunked_datasets) >= self.node_number)
+        chunk_size = len(chunked_datasets) // self.node_number
+        current_example_num = self.normal_dataset_size
+
+        for index in range(len(chunked_datasets)):
+            if index == self.node_rank * chunk_size:
+                self.current_node_start_range = current_example_num
+            current_example_num += len(chunked_datasets[index])
+            if index == (self.node_rank + 1) * chunk_size - 1:
+                self.current_node_end_range = current_example_num
+
+        if self.current_node_end_range == -1:  # boundary
+            self.current_node_end_range = current_example_num
+
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                # `type:ignore` is required because Dataset cannot provide a default __len__
+                # see NOTE in pytorch/torch/utils/data/sampler.py
+                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self):
+        indices = self.generate_indices_within_range_with_rank(
+            seed=self.seed,
+            epoch=self.epoch,
+
+            # NOTE: Distribute among all processes
+            process_num=self.num_replicas,
+            rank=self.rank,
+            generate_length=-1,
+            valid_indices=list(range(self.normal_dataset_size)),
+            prefix="Normal "
+        )
+
+        addition_indices = self.generate_indices_within_range_with_rank(
+            seed=self.seed,
+            epoch=self.epoch,
+
+            # NOTE : very important arguments, distribute among local nodes
+            process_num=self.process_num_per_node,
+            rank=self.rank_within_local_node,
+
+            generate_length=self.num_samples - len(indices),
+            valid_indices=list(range(self.current_node_start_range, self.current_node_end_range)),
+            prefix="Distribute "
+        )
+
+        indices.extend(addition_indices)
+        random.seed(self.seed + self.epoch + 10 * self.rank)  # Set the seed to maximize randomness
+        random.shuffle(indices)  # Reshuffle
+        assert len(indices) == self.num_samples
+        return iter(indices)
+
+    def generate_indices_within_range_with_rank(self, seed, epoch, process_num, generate_length, valid_indices, rank=-1,
+                                                shuffle=True, prefix=""):
+        '''
+        Use scenario : we want to sample 2500 examples from 10000 examples, while not sampling overlapping examples with other three process.
+        Modified from DistributedSampler
+        '''
+        dataset_size = len(valid_indices)
+        if shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(seed + epoch)
+            indices = torch.randperm(dataset_size, generator=g).tolist()  # type: ignore[arg-type]
+        else:
+            indices = list(range(dataset_size))  # type: ignore[arg-type]
+
+        indices = [valid_indices[i] for i in indices]
+
+        num_samples_normal = math.ceil(
+            (dataset_size - process_num) / process_num  # type: ignore[arg-type]
+        )
+        # remove tail of data to make it evenly divisible.
+        indices = indices[:num_samples_normal * process_num]
+
+        print("\n")
+        print(prefix,
+              "Global Rank {}   Local Rank {}    generate_length {}    valid_indices {}    process_num {}  indices_before_subsample {} {}".format(
+                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))
+
+        # subsample
+        indices = indices[rank:num_samples_normal * process_num: process_num]
+
+        print(prefix,
+              "Global Rank {}   Local Rank {}    generate_length {}    valid_indices {}    process_num {}  indices_after_subsample {} {}".format(
+                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))
+        print("\n")
+
+        if generate_length != -1:
+            if len(indices) > generate_length:
+                indices = indices[:generate_length]
+            else:
+                indices.extend(np.random.choice(valid_indices, generate_length - len(indices)).tolist())
+        return indices
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        r"""
+        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
diff --git a/maskrcnn_benchmark/data/datasets/duplicate_dataset.py b/maskrcnn_benchmark/data/datasets/duplicate_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c3506c968d496bdc90954418c8f955f6012beb6
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/duplicate_dataset.py
@@ -0,0 +1,31 @@
+import math
+from typing import TypeVar, Optional, Iterator
+
+import torch
+from torch.utils.data import Sampler, Dataset
+import torch.distributed as dist
+import random
+import numpy as np
+
+
+def create_duplicate_dataset(DatasetBaseClass):
+    class DupDataset(DatasetBaseClass):
+
+        def __init__(self, copy, **kwargs):
+            super(DupDataset, self).__init__(**kwargs)
+
+            self.copy = copy
+            self.length = super(DupDataset, self).__len__()
+
+        def __len__(self):
+            return self.copy * self.length
+
+        def __getitem__(self, index):
+            true_index = index % self.length
+            return super(DupDataset, self).__getitem__(true_index)
+
+        def get_img_info(self, index):
+            true_index = index % self.length
+            return super(DupDataset, self).get_img_info(true_index)
+
+    return DupDataset
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d84ead6b86ddcc1e0ae4088a5e36546ebef0efd
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/__init__.py
@@ -0,0 +1,56 @@
+from maskrcnn_benchmark.data import datasets
+
+from .coco import coco_evaluation
+from .voc import voc_evaluation
+from .vg import vg_evaluation
+from .box_aug import im_detect_bbox_aug
+from .od_to_grounding import od_to_grounding_evaluation
+
+
+def evaluate(dataset, predictions, output_folder, **kwargs):
+    """evaluate dataset using different methods based on dataset type.
+    Args:
+        dataset: Dataset object
+        predictions(list[BoxList]): each item in the list represents the
+            prediction results for one image.
+        output_folder: output folder, to save evaluation files or results.
+        **kwargs: other args.
+    Returns:
+        evaluation result
+    """
+    args = dict(
+        dataset=dataset, predictions=predictions, output_folder=output_folder, **kwargs
+    )
+    if isinstance(dataset, datasets.COCODataset) or isinstance(dataset, datasets.TSVDataset):
+        return coco_evaluation(**args)
+    # elif isinstance(dataset, datasets.VGTSVDataset):
+    #     return vg_evaluation(**args)
+    elif isinstance(dataset, datasets.PascalVOCDataset):
+        return voc_evaluation(**args)
+    elif isinstance(dataset, datasets.CocoDetectionTSV):
+        return od_to_grounding_evaluation(**args)
+    elif isinstance(dataset, datasets.LvisDetection):
+        pass
+    else:
+        dataset_name = dataset.__class__.__name__
+        raise NotImplementedError("Unsupported dataset type {}.".format(dataset_name))
+
+
+def evaluate_mdetr(dataset, predictions, output_folder, cfg):
+   
+    args = dict(
+        dataset=dataset, predictions=predictions, output_folder=output_folder, **kwargs
+    )
+    if isinstance(dataset, datasets.COCODataset) or isinstance(dataset, datasets.TSVDataset):
+        return coco_evaluation(**args)
+    # elif isinstance(dataset, datasets.VGTSVDataset):
+    #     return vg_evaluation(**args)
+    elif isinstance(dataset, datasets.PascalVOCDataset):
+        return voc_evaluation(**args)
+    elif isinstance(dataset, datasets.CocoDetectionTSV):
+        return od_to_grounding_evaluation(**args)
+    elif isinstance(dataset, datasets.LvisDetection):
+        pass
+    else:
+        dataset_name = dataset.__class__.__name__
+        raise NotImplementedError("Unsupported dataset type {}.".format(dataset_name))
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/box_aug.py b/maskrcnn_benchmark/data/datasets/evaluation/box_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..da7e5d1907bc3fa69ce85a78723479988e532b2e
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/box_aug.py
@@ -0,0 +1,349 @@
+import torch
+import numpy as np
+
+from maskrcnn_benchmark.config import cfg
+from maskrcnn_benchmark.data import transforms as T
+from maskrcnn_benchmark.structures.image_list import to_image_list
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.layers import nms, soft_nms
+
+
+def im_detect_bbox_aug(model, images, device, captions=None, positive_map_label_to_token=None):
+    # Collect detections computed under different transformations
+    boxlists_ts = []
+    for _ in range(len(images)):
+        boxlists_ts.append([])
+
+    def add_preds_t(boxlists_t):
+        for i, boxlist_t in enumerate(boxlists_t):
+            # Resize the boxlist as the first one
+            boxlists_ts[i].append(boxlist_t.resize(images[i].size))
+
+    # Compute detections at different scales
+    if len(cfg.TEST.RANGES)==len(cfg.TEST.SCALES):
+        keep_ranges = cfg.TEST.RANGES
+    else:
+        keep_ranges = [None for _ in cfg.TEST.SCALES]
+
+    for scale, keep_range in zip(cfg.TEST.SCALES, keep_ranges):
+        max_size = cfg.TEST.MAX_SIZE
+        boxlists_scl = im_detect_bbox_scale(
+            model, images, scale, max_size, device,
+            captions=captions,
+            positive_map_label_to_token=positive_map_label_to_token,
+        )
+        if keep_range is not None:
+            boxlists_scl = remove_boxes(boxlists_scl, *keep_range)
+        add_preds_t(boxlists_scl)
+
+        if cfg.TEST.FLIP:
+            boxlists_scl_hf = im_detect_bbox_scale(
+                model, images, scale, max_size, device,
+                captions=captions,
+                positive_map_label_to_token=positive_map_label_to_token,
+                hflip=True
+            )
+            if keep_range is not None:
+                boxlists_scl_hf = remove_boxes(boxlists_scl_hf, *keep_range)
+            add_preds_t(boxlists_scl_hf)
+
+    # Merge boxlists detected by different bbox aug params
+    boxlists = []
+    for i, boxlist_ts in enumerate(boxlists_ts):
+        bbox = torch.cat([boxlist_t.bbox for boxlist_t in boxlist_ts])
+        scores = torch.cat([boxlist_t.get_field('scores') for boxlist_t in boxlist_ts])
+        labels = torch.cat([boxlist_t.get_field('labels') for boxlist_t in boxlist_ts])
+        boxlist = BoxList(bbox, boxlist_ts[0].size, boxlist_ts[0].mode)
+        boxlist.add_field('scores', scores)
+        boxlist.add_field('labels', labels)
+        boxlists.append(boxlist)
+    results = merge_result_from_multi_scales(boxlists)
+    return results
+
+
+def im_detect_bbox(model, images, target_scale, target_max_size, device,
+                   captions=None,
+                   positive_map_label_to_token=None
+                   ):
+    """
+    Performs bbox detection on the original image.
+    """
+    if cfg.INPUT.FORMAT is not '':
+        input_format = cfg.INPUT.FORMAT
+    elif cfg.INPUT.TO_BGR255:
+        input_format = 'bgr255'
+    transform = T.Compose([
+        T.Resize(target_scale, target_max_size),
+        T.ToTensor(),
+        T.Normalize(
+            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, format=input_format
+        )
+    ])
+    images = [transform(image) for image in images]
+    images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
+    if captions is None:
+        return model(images.to(device))
+    else:
+        return model(images.to(device),
+                     captions=captions,
+                     positive_map=positive_map_label_to_token
+                     )
+
+
+def im_detect_bbox_hflip(model, images, target_scale, target_max_size, device,
+                         captions=None,
+                         positive_map_label_to_token=None
+                         ):
+    """
+    Performs bbox detection on the horizontally flipped image.
+    Function signature is the same as for im_detect_bbox.
+    """
+    if cfg.INPUT.FORMAT is not '':
+        input_format = cfg.INPUT.FORMAT
+    elif cfg.INPUT.TO_BGR255:
+        input_format = 'bgr255'
+    transform = T.Compose([
+        T.Resize(target_scale, target_max_size),
+        T.RandomHorizontalFlip(1.0),
+        T.ToTensor(),
+        T.Normalize(
+            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, format=input_format
+        )
+    ])
+    images = [transform(image) for image in images]
+    images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
+    if captions is None:
+        boxlists = model(images.to(device))
+    else:
+        boxlists = model(images.to(device),
+                         captions=captions,
+                         positive_map=positive_map_label_to_token
+                         )
+
+    # Invert the detections computed on the flipped image
+    boxlists_inv = [boxlist.transpose(0) for boxlist in boxlists]
+    return boxlists_inv
+
+
+def im_detect_bbox_scale(model, images, target_scale, target_max_size, device,
+                         captions=None,
+                         positive_map_label_to_token=None,
+                         hflip=False):
+    """
+    Computes bbox detections at the given scale.
+    Returns predictions in the scaled image space.
+    """
+    if hflip:
+        boxlists_scl = im_detect_bbox_hflip(model, images, target_scale, target_max_size, device,
+                                            captions=captions,
+                                            positive_map_label_to_token=positive_map_label_to_token
+                                            )
+    else:
+        boxlists_scl = im_detect_bbox(model, images, target_scale, target_max_size, device,
+                                      captions=captions,
+                                      positive_map_label_to_token=positive_map_label_to_token
+                                      )
+    return boxlists_scl
+
+
+def remove_boxes(boxlist_ts, min_scale, max_scale):
+    new_boxlist_ts = []
+    for _, boxlist_t in enumerate(boxlist_ts):
+        mode = boxlist_t.mode
+        boxlist_t = boxlist_t.convert("xyxy")
+        boxes = boxlist_t.bbox
+        keep = []
+        for j, box in enumerate(boxes):
+            w = box[2] - box[0] + 1
+            h = box[3] - box[1] + 1
+            if (w * h > min_scale * min_scale) and (w * h < max_scale * max_scale):
+                keep.append(j)
+        new_boxlist_ts.append(boxlist_t[keep].convert(mode))
+    return new_boxlist_ts
+
+
+def merge_result_from_multi_scales(boxlists):
+    num_images = len(boxlists)
+    results = []
+    for i in range(num_images):
+        scores = boxlists[i].get_field("scores")
+        labels = boxlists[i].get_field("labels")
+        boxes = boxlists[i].bbox
+        boxlist = boxlists[i]
+        result = []
+        # test on classes
+        if len(cfg.TEST.SELECT_CLASSES):
+            class_list = cfg.TEST.SELECT_CLASSES
+        else:
+            class_list = range(1, cfg.TEST.NUM_CLASSES)
+        for j in class_list:
+            inds = (labels == j).nonzero().view(-1)
+
+            scores_j = scores[inds]
+            boxes_j = boxes[inds, :].view(-1, 4)
+            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
+            boxlist_for_class.add_field("scores", scores_j)
+            boxlist_for_class = boxlist_nms(boxlist_for_class, cfg.TEST.TH, score_field="scores", nms_type=cfg.TEST.SPECIAL_NMS)
+            num_labels = len(boxlist_for_class)
+            boxlist_for_class.add_field("labels", torch.full((num_labels,), j, dtype=torch.int64, device=scores.device))
+            result.append(boxlist_for_class)
+
+        result = cat_boxlist(result)
+        number_of_detections = len(result)
+
+        # Limit to max_per_image detections **over all classes**
+        if number_of_detections > cfg.TEST.PRE_NMS_TOP_N > 0:
+            cls_scores = result.get_field("scores")
+            image_thresh, _ = torch.kthvalue(
+                cls_scores.cpu(),
+                number_of_detections - cfg.TEST.PRE_NMS_TOP_N + 1
+            )
+            keep = cls_scores >= image_thresh.item()
+            keep = torch.nonzero(keep).squeeze(1)
+            result = result[keep]
+        results.append(result)
+    return results
+
+
+def boxlist_nms(boxlist, thresh, max_proposals=-1, score_field="scores", nms_type='nms'):
+    if thresh <= 0:
+        return boxlist
+    mode = boxlist.mode
+    boxlist = boxlist.convert("xyxy")
+    boxes = boxlist.bbox
+    score = boxlist.get_field(score_field)
+
+    if nms_type == 'vote':
+        boxes_vote, scores_vote = bbox_vote(boxes, score, thresh)
+        if len(boxes_vote) > 0:
+            boxlist.bbox = boxes_vote
+            boxlist.extra_fields['scores'] = scores_vote
+    elif nms_type == 'soft-vote':
+        boxes_vote, scores_vote = soft_bbox_vote(boxes, score, thresh)
+        if len(boxes_vote) > 0:
+            boxlist.bbox = boxes_vote
+            boxlist.extra_fields['scores'] = scores_vote
+    elif nms_type == 'soft-nms':
+        keep, new_score = soft_nms(boxes.cpu(), score.cpu(), thresh, 0.95)
+        if max_proposals > 0:
+            keep = keep[: max_proposals]
+        boxlist = boxlist[keep]
+        boxlist.extra_fields['scores'] = new_score
+    else:
+        keep = nms(boxes, score, thresh)
+        if max_proposals > 0:
+            keep = keep[: max_proposals]
+        boxlist = boxlist[keep]
+    return boxlist.convert(mode)
+
+
+def bbox_vote(boxes, scores, vote_thresh):
+    boxes = boxes.cpu().numpy()
+    scores = scores.cpu().numpy().reshape(-1, 1)
+    det = np.concatenate((boxes, scores), axis=1)
+    if det.shape[0] <= 1:
+        return np.zeros((0, 5)), np.zeros((0, 1))
+    order = det[:, 4].ravel().argsort()[::-1]
+    det = det[order, :]
+    dets = []
+    while det.shape[0] > 0:
+        # IOU
+        area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
+        xx1 = np.maximum(det[0, 0], det[:, 0])
+        yy1 = np.maximum(det[0, 1], det[:, 1])
+        xx2 = np.minimum(det[0, 2], det[:, 2])
+        yy2 = np.minimum(det[0, 3], det[:, 3])
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        o = inter / (area[0] + area[:] - inter)
+
+        # get needed merge det and delete these  det
+        merge_index = np.where(o >= vote_thresh)[0]
+        det_accu = det[merge_index, :]
+        det = np.delete(det, merge_index, 0)
+
+        if merge_index.shape[0] <= 1:
+            try:
+                dets = np.row_stack((dets, det_accu))
+            except:
+                dets = det_accu
+            continue
+        else:
+            det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
+            max_score = np.max(det_accu[:, 4])
+            det_accu_sum = np.zeros((1, 5))
+            det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:])
+            det_accu_sum[:, 4] = max_score
+            try:
+                dets = np.row_stack((dets, det_accu_sum))
+            except:
+                dets = det_accu_sum
+
+    boxes = torch.from_numpy(dets[:, :4]).float().cuda()
+    scores = torch.from_numpy(dets[:, 4]).float().cuda()
+
+    return boxes, scores
+
+
+def soft_bbox_vote(boxes, scores, vote_thresh):
+    boxes = boxes.cpu().numpy()
+    scores = scores.cpu().numpy().reshape(-1, 1)
+    det = np.concatenate((boxes, scores), axis=1)
+    if det.shape[0] <= 1:
+        return np.zeros((0, 5)), np.zeros((0, 1))
+    order = det[:, 4].ravel().argsort()[::-1]
+    det = det[order, :]
+    dets = []
+    while det.shape[0] > 0:
+        # IOU
+        area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
+        xx1 = np.maximum(det[0, 0], det[:, 0])
+        yy1 = np.maximum(det[0, 1], det[:, 1])
+        xx2 = np.minimum(det[0, 2], det[:, 2])
+        yy2 = np.minimum(det[0, 3], det[:, 3])
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        o = inter / (area[0] + area[:] - inter)
+
+        # get needed merge det and delete these  det
+        merge_index = np.where(o >= vote_thresh)[0]
+        det_accu = det[merge_index, :]
+        det_accu_iou = o[merge_index]
+        det = np.delete(det, merge_index, 0)
+
+        if merge_index.shape[0] <= 1:
+            try:
+                dets = np.row_stack((dets, det_accu))
+            except:
+                dets = det_accu
+            continue
+        else:
+            soft_det_accu = det_accu.copy()
+            soft_det_accu[:, 4] = soft_det_accu[:, 4] * (1 - det_accu_iou)
+            soft_index = np.where(soft_det_accu[:, 4] >= cfg.MODEL.RETINANET.INFERENCE_TH)[0]
+            soft_det_accu = soft_det_accu[soft_index, :]
+
+            det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
+            max_score = np.max(det_accu[:, 4])
+            det_accu_sum = np.zeros((1, 5))
+            det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:])
+            det_accu_sum[:, 4] = max_score
+
+            if soft_det_accu.shape[0] > 0:
+                det_accu_sum = np.row_stack((det_accu_sum, soft_det_accu))
+
+            try:
+                dets = np.row_stack((dets, det_accu_sum))
+            except:
+                dets = det_accu_sum
+
+    order = dets[:, 4].ravel().argsort()[::-1]
+    dets = dets[order, :]
+
+    boxes = torch.from_numpy(dets[:, :4]).float().cuda()
+    scores = torch.from_numpy(dets[:, 4]).float().cuda()
+
+    return boxes, scores
\ No newline at end of file
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/coco/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/coco/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a25c9b536e131b4d8bfd8e7ceb24c783d8d97cd
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/coco/__init__.py
@@ -0,0 +1,21 @@
+from .coco_eval import do_coco_evaluation
+
+
+def coco_evaluation(
+    dataset,
+    predictions,
+    output_folder,
+    box_only=False,
+    iou_types=("bbox",),
+    expected_results=(),
+    expected_results_sigma_tol=4,
+):
+    return do_coco_evaluation(
+        dataset=dataset,
+        predictions=predictions,
+        box_only=box_only,
+        output_folder=output_folder,
+        iou_types=iou_types,
+        expected_results=expected_results,
+        expected_results_sigma_tol=expected_results_sigma_tol,
+    )
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..be79d7429b14c848ca161ccfe434512749c06af8
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py
@@ -0,0 +1,531 @@
+import logging
+import tempfile
+import os
+import torch
+import numpy as np
+import json
+
+from collections import OrderedDict
+from tqdm import tqdm
+
+from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+
+
+def do_coco_evaluation(
+        dataset,
+        predictions,
+        box_only,
+        output_folder,
+        iou_types,
+        expected_results,
+        expected_results_sigma_tol,
+):
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+
+    if box_only:
+        logger.info("Evaluating bbox proposals")
+        if dataset.coco is None and output_folder:
+            json_results = prepare_for_tsv_detection(predictions, dataset)
+            with open(os.path.join(output_folder, "box_proposals.json"), "w") as f:
+                json.dump(json_results, f)
+            return None
+        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
+        res = COCOResults("box_proposal")
+        for limit in [100, 1000]:
+            for area, suffix in areas.items():
+                stats = evaluate_box_proposals(
+                    predictions, dataset, area=area, limit=limit
+                )
+                key = "AR{}@{:d}".format(suffix, limit)
+                res.results["box_proposal"][key] = stats["ar"].item()
+        logger.info(res)
+        check_expected_results(res, expected_results, expected_results_sigma_tol)
+        if output_folder:
+            torch.save(res, os.path.join(output_folder, "box_proposals.pth"))
+        return res, predictions
+    logger.info("Preparing results for COCO format")
+    coco_results = {}
+    if "bbox" in iou_types:
+        logger.info("Preparing bbox results")
+        if dataset.coco is None:
+            coco_results["bbox"] = prepare_for_tsv_detection(predictions, dataset)
+        else:
+            coco_results["bbox"] = prepare_for_coco_detection(predictions, dataset)
+    if "segm" in iou_types:
+        logger.info("Preparing segm results")
+        coco_results["segm"] = prepare_for_coco_segmentation(predictions, dataset)
+    if 'keypoints' in iou_types:
+        logger.info('Preparing keypoints results')
+        coco_results['keypoints'] = prepare_for_coco_keypoint(predictions, dataset)
+
+    results = COCOResults(*iou_types)
+    logger.info("Evaluating predictions")
+    for iou_type in iou_types:
+        with tempfile.NamedTemporaryFile() as f:
+            file_path = f.name
+            if output_folder:
+                file_path = os.path.join(output_folder, iou_type + ".json")
+            if dataset.coco:
+                res = evaluate_predictions_on_coco(
+                    dataset.coco, coco_results[iou_type], file_path, iou_type
+                )
+                results.update(res)
+            elif output_folder:
+                with open(file_path, "w") as f:
+                    json.dump(coco_results[iou_type], f)
+
+    logger.info(results)
+    check_expected_results(results, expected_results, expected_results_sigma_tol)
+    if output_folder:
+        torch.save(results, os.path.join(output_folder, "coco_results.pth"))
+    return results, coco_results
+
+
+def prepare_for_tsv_detection(predictions, dataset):
+    # assert isinstance(dataset, COCODataset)
+    proposal_results = []
+    image_list = []
+    for im_id, prediction in enumerate(predictions):
+        image_info = dataset.get_img_info(im_id)
+        if len(prediction) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_id = image_info["id"]
+        image_width = image_info["width"]
+        image_height = image_info["height"]
+        prediction = prediction.resize((image_width, image_height))
+        prediction = prediction.convert("xywh")
+
+        boxes = prediction.bbox.tolist()
+        scores = prediction.get_field("scores").tolist()
+        labels = prediction.get_field("labels").tolist()
+        if prediction.has_field("centers"):
+            centers = prediction.get_field("centers")
+        else:
+            centers = None
+
+        for k, box in enumerate(boxes):
+            proposal = {
+                "image_id": image_id,
+                "category_id": labels[k],
+                "bbox": box,
+                "score": scores[k],
+                "area": image_width * image_height,
+                "iscrowd": 0,
+            }
+            if centers is not None:
+                proposal.update(center=centers[k].tolist())
+            proposal_results.append(proposal)
+
+        image_list.append(image_info)
+
+        # categories = [{'supercategory': 'proposal', 'id': 0, 'name': 'proposal'}]
+    return dict(images=image_list, annotations=proposal_results)
+
+
+def prepare_for_coco_detection(predictions, dataset):
+    # assert isinstance(dataset, COCODataset)
+    coco_results = []
+    for image_id, prediction in enumerate(predictions):
+        original_id = dataset.id_to_img_map[image_id]
+        if len(prediction) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]["width"]
+        image_height = dataset.coco.imgs[original_id]["height"]
+        prediction = prediction.resize((image_width, image_height))
+        prediction = prediction.convert("xywh")
+
+        boxes = prediction.bbox.tolist()
+        scores = prediction.get_field("scores").tolist()
+        labels = prediction.get_field("labels").tolist()
+
+        for k, box in enumerate(boxes):
+            if labels[k] in dataset.contiguous_category_id_to_json_id:
+                coco_results.append(
+                    {
+                        "image_id": original_id,
+                        "category_id": dataset.contiguous_category_id_to_json_id[labels[k]],
+                        "bbox": box,
+                        "score": scores[k],
+                    })
+
+    return coco_results
+
+
+def prepare_for_coco_segmentation(predictions, dataset):
+    import pycocotools.mask as mask_util
+    import numpy as np
+
+    masker = Masker(threshold=0.5, padding=1)
+    # assert isinstance(dataset, COCODataset)
+    coco_results = []
+    for image_id, prediction in tqdm(enumerate(predictions)):
+        original_id = dataset.id_to_img_map[image_id]
+        if len(prediction) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]["width"]
+        image_height = dataset.coco.imgs[original_id]["height"]
+        prediction = prediction.resize((image_width, image_height))
+        masks = prediction.get_field("mask")
+        # t = time.time()
+        # Masker is necessary only if masks haven't been already resized.
+        if list(masks.shape[-2:]) != [image_height, image_width]:
+            masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
+            masks = masks[0]
+        # logger.info('Time mask: {}'.format(time.time() - t))
+        # prediction = prediction.convert('xywh')
+
+        # boxes = prediction.bbox.tolist()
+        scores = prediction.get_field("scores").tolist()
+        labels = prediction.get_field("labels").tolist()
+
+        # rles = prediction.get_field('mask')
+
+        rles = [
+            mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
+            for mask in masks
+        ]
+        for rle in rles:
+            rle["counts"] = rle["counts"].decode("utf-8")
+
+        mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
+
+        coco_results.extend(
+            [
+                {
+                    "image_id": original_id,
+                    "category_id": mapped_labels[k],
+                    "segmentation": rle,
+                    "score": scores[k],
+                }
+                for k, rle in enumerate(rles)
+            ]
+        )
+    return coco_results
+
+
+def prepare_for_coco_keypoint(predictions, dataset):
+    # assert isinstance(dataset, COCODataset)
+    coco_results = []
+    for image_id, prediction in enumerate(predictions):
+        original_id = dataset.id_to_img_map[image_id]
+        if len(prediction.bbox) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]['width']
+        image_height = dataset.coco.imgs[original_id]['height']
+        prediction = prediction.resize((image_width, image_height))
+        prediction = prediction.convert('xywh')
+
+        boxes = prediction.bbox.tolist()
+        scores = prediction.get_field('scores').tolist()
+        labels = prediction.get_field('labels').tolist()
+        keypoints = prediction.get_field('keypoints')
+        keypoints = keypoints.resize((image_width, image_height))
+        keypoints = keypoints.to_coco_format()
+
+        mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
+
+        coco_results.extend([{
+            'image_id': original_id,
+            'category_id': mapped_labels[k],
+            'keypoints': keypoint,
+            'score': scores[k]} for k, keypoint in enumerate(keypoints)])
+    return coco_results
+
+
+# inspired from Detectron
+def evaluate_box_proposals(
+        predictions, dataset, thresholds=None, area="all", limit=None
+):
+    """Evaluate detection proposal recall metrics. This function is a much
+    faster alternative to the official COCO API recall evaluation code. However,
+    it produces slightly different results.
+    """
+    # Record max overlap value for each gt box
+    # Return vector of overlap values
+    areas = {
+        "all": 0,
+        "small": 1,
+        "medium": 2,
+        "large": 3,
+        "96-128": 4,
+        "128-256": 5,
+        "256-512": 6,
+        "512-inf": 7,
+    }
+    area_ranges = [
+        [0 ** 2, 1e5 ** 2],  # all
+        [0 ** 2, 32 ** 2],  # small
+        [32 ** 2, 96 ** 2],  # medium
+        [96 ** 2, 1e5 ** 2],  # large
+        [96 ** 2, 128 ** 2],  # 96-128
+        [128 ** 2, 256 ** 2],  # 128-256
+        [256 ** 2, 512 ** 2],  # 256-512
+        [512 ** 2, 1e5 ** 2],
+    ]  # 512-inf
+    assert area in areas, "Unknown area range: {}".format(area)
+    area_range = area_ranges[areas[area]]
+    gt_overlaps = []
+    num_pos = 0
+
+    for image_id, prediction in enumerate(predictions):
+        original_id = dataset.id_to_img_map[image_id]
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]["width"]
+        image_height = dataset.coco.imgs[original_id]["height"]
+        prediction = prediction.resize((image_width, image_height))
+
+        # sort predictions in descending order
+        # TODO maybe remove this and make it explicit in the documentation
+        if prediction.has_field("objectness"):
+            inds = prediction.get_field("objectness").sort(descending=True)[1]
+        else:
+            inds = prediction.get_field("scores").sort(descending=True)[1]
+        prediction = prediction[inds]
+
+        ann_ids = dataset.coco.getAnnIds(imgIds=original_id)
+        anno = dataset.coco.loadAnns(ann_ids)
+        gt_boxes = [obj["bbox"] for obj in anno if obj["iscrowd"] == 0]
+        gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4)  # guard against no boxes
+        gt_boxes = BoxList(gt_boxes, (image_width, image_height), mode="xywh").convert(
+            "xyxy"
+        )
+        gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
+
+        if len(gt_boxes) == 0:
+            continue
+
+        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
+        gt_boxes = gt_boxes[valid_gt_inds]
+
+        num_pos += len(gt_boxes)
+
+        if len(gt_boxes) == 0:
+            continue
+
+        if len(prediction) == 0:
+            continue
+
+        if limit is not None and len(prediction) > limit:
+            prediction = prediction[:limit]
+
+        overlaps = boxlist_iou(prediction, gt_boxes)
+
+        _gt_overlaps = torch.zeros(len(gt_boxes))
+        for j in range(min(len(prediction), len(gt_boxes))):
+            # find which proposal box maximally covers each gt box
+            # and get the iou amount of coverage for each gt box
+            max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+            # find which gt box is 'best' covered (i.e. 'best' = most iou)
+            gt_ovr, gt_ind = max_overlaps.max(dim=0)
+            assert gt_ovr >= 0
+            # find the proposal box that covers the best covered gt box
+            box_ind = argmax_overlaps[gt_ind]
+            # record the iou coverage of this gt box
+            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
+            assert _gt_overlaps[j] == gt_ovr
+            # mark the proposal box and the gt box as used
+            overlaps[box_ind, :] = -1
+            overlaps[:, gt_ind] = -1
+
+        # append recorded iou coverage level
+        gt_overlaps.append(_gt_overlaps)
+
+    if len(gt_overlaps) == 0:
+        return {
+            "ar": torch.zeros(1),
+            "recalls": torch.zeros(1),
+            "thresholds": thresholds,
+            "gt_overlaps": gt_overlaps,
+            "num_pos": num_pos,
+        }
+
+    gt_overlaps = torch.cat(gt_overlaps, dim=0)
+    gt_overlaps, _ = torch.sort(gt_overlaps)
+
+    if thresholds is None:
+        step = 0.05
+        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
+    recalls = torch.zeros_like(thresholds)
+    # compute recall for each iou threshold
+    for i, t in enumerate(thresholds):
+        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
+    # ar = 2 * np.trapz(recalls, thresholds)
+    ar = recalls.mean()
+    return {
+        "ar": ar,
+        "recalls": recalls,
+        "thresholds": thresholds,
+        "gt_overlaps": gt_overlaps,
+        "num_pos": num_pos,
+    }
+
+
+def evaluate_predictions_on_coco(
+        coco_gt, coco_results, json_result_file, iou_type="bbox"
+):
+    import json
+
+    with open(json_result_file, "w") as f:
+        json.dump(coco_results, f)
+
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+
+    coco_dt = coco_gt.loadRes(str(json_result_file)) if coco_results else COCO()
+
+    # coco_dt = coco_gt.loadRes(coco_results)
+    if iou_type == 'keypoints':
+        coco_gt = filter_valid_keypoints(coco_gt, coco_dt)
+    coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
+    coco_eval.evaluate()
+    coco_eval.accumulate()
+    coco_eval.summarize()
+    if iou_type == 'bbox':
+        summarize_per_category(coco_eval, json_result_file.replace('.json', '.csv'))
+    return coco_eval
+
+
+def summarize_per_category(coco_eval, csv_output=None):
+    '''
+    Compute and display summary metrics for evaluation results.
+    Note this functin can *only* be applied on the default parameter setting
+    '''
+
+    def _summarize(iouThr=None, areaRng='all', maxDets=100):
+        p = coco_eval.params
+        titleStr = 'Average Precision'
+        typeStr = '(AP)'
+        iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
+            if iouThr is None else '{:0.2f}'.format(iouThr)
+        result_str = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ], '. \
+            format(titleStr, typeStr, iouStr, areaRng, maxDets)
+
+        aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
+        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+
+        # dimension of precision: [TxRxKxAxM]
+        s = coco_eval.eval['precision']
+        # IoU
+        if iouThr is not None:
+            t = np.where(iouThr == p.iouThrs)[0]
+            s = s[t]
+        s = s[:, :, :, aind, mind]
+
+        if len(s[s > -1]) == 0:
+            mean_s = -1
+        else:
+            mean_s = np.mean(s[s > -1])
+            # cacluate AP(average precision) for each category
+            num_classes = len(p.catIds)
+            avg_ap = 0.0
+            for i in range(0, num_classes):
+                result_str += '{}, '.format(np.mean(s[:, :, i, :]))
+                avg_ap += np.mean(s[:, :, i, :])
+            result_str += ('{} \n'.format(avg_ap / num_classes))
+        return result_str
+
+    id2name = {}
+    for _, cat in coco_eval.cocoGt.cats.items():
+        id2name[cat['id']] = cat['name']
+    title_str = 'metric, '
+    for cid in coco_eval.params.catIds:
+        title_str += '{}, '.format(id2name[cid])
+    title_str += 'avg \n'
+
+    results = [title_str]
+    results.append(_summarize())
+    results.append(_summarize(iouThr=.5, maxDets=coco_eval.params.maxDets[2]))
+    results.append(_summarize(areaRng='small', maxDets=coco_eval.params.maxDets[2]))
+    results.append(_summarize(areaRng='medium', maxDets=coco_eval.params.maxDets[2]))
+    results.append(_summarize(areaRng='large', maxDets=coco_eval.params.maxDets[2]))
+
+    with open(csv_output, 'w') as f:
+        for result in results:
+            f.writelines(result)
+
+
+def filter_valid_keypoints(coco_gt, coco_dt):
+    kps = coco_dt.anns[1]['keypoints']
+    for id, ann in coco_gt.anns.items():
+        ann['keypoints'][2::3] = [a * b for a, b in zip(ann['keypoints'][2::3], kps[2::3])]
+        ann['num_keypoints'] = sum(ann['keypoints'][2::3])
+    return coco_gt
+
+
+class COCOResults(object):
+    METRICS = {
+        "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
+        "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
+        "box_proposal": [
+            "AR@100",
+            "ARs@100",
+            "ARm@100",
+            "ARl@100",
+            "AR@1000",
+            "ARs@1000",
+            "ARm@1000",
+            "ARl@1000",
+        ],
+        "keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
+    }
+
+    def __init__(self, *iou_types):
+        allowed_types = ("box_proposal", "bbox", "segm", "keypoints")
+        assert all(iou_type in allowed_types for iou_type in iou_types)
+        results = OrderedDict()
+        for iou_type in iou_types:
+            results[iou_type] = OrderedDict(
+                [(metric, -1) for metric in COCOResults.METRICS[iou_type]]
+            )
+        self.results = results
+
+    def update(self, coco_eval):
+        if coco_eval is None:
+            return
+        from pycocotools.cocoeval import COCOeval
+
+        assert isinstance(coco_eval, COCOeval)
+        s = coco_eval.stats
+        iou_type = coco_eval.params.iouType
+        res = self.results[iou_type]
+        metrics = COCOResults.METRICS[iou_type]
+        for idx, metric in enumerate(metrics):
+            res[metric] = s[idx]
+
+    def __repr__(self):
+        # TODO make it pretty
+        return repr(self.results)
+
+
+def check_expected_results(results, expected_results, sigma_tol):
+    if not expected_results:
+        return
+
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+    for task, metric, (mean, std) in expected_results:
+        actual_val = results.results[task][metric]
+        lo = mean - sigma_tol * std
+        hi = mean + sigma_tol * std
+        ok = (lo < actual_val) and (actual_val < hi)
+        msg = (
+            "{} > {} sanity check (actual vs. expected): "
+            "{:.3f} vs. mean={:.4f}, std={:.4}, range=({:.4f}, {:.4f})"
+        ).format(task, metric, actual_val, mean, std, lo, hi)
+        if not ok:
+            msg = "FAIL: " + msg
+            logger.error(msg)
+        else:
+            msg = "PASS: " + msg
+            logger.info(msg)
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/flickr/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/flickr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd063073c837183ac09aee7c6bbc4d8ad9dd47ef
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/flickr/__init__.py
@@ -0,0 +1 @@
+from .flickr_eval import FlickrEvaluator
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/flickr/flickr_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/flickr/flickr_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..394bd59aa88b4b9ca67ca7b5ad18f390befe9b99
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/flickr/flickr_eval.py
@@ -0,0 +1,440 @@
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+import json
+import numpy as np
+import os.path as osp
+import os
+from prettytable import PrettyTable
+
+import xml.etree.ElementTree as ET
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import maskrcnn_benchmark.utils.mdetr_dist as dist
+#### The following loading utilities are imported from
+#### https://github.com/BryanPlummer/flickr30k_entities/blob/68b3d6f12d1d710f96233f6bd2b6de799d6f4e5b/flickr30k_entities_utils.py
+# Changelog:
+#    - Added typing information
+#    - Completed docstrings
+
+def get_sentence_data(filename) -> List[Dict[str, Any]]:
+    """
+    Parses a sentence file from the Flickr30K Entities dataset
+
+    input:
+      filename - full file path to the sentence file to parse
+
+    output:
+      a list of dictionaries for each sentence with the following fields:
+          sentence - the original sentence
+          phrases - a list of dictionaries for each phrase with the
+                    following fields:
+                      phrase - the text of the annotated phrase
+                      first_word_index - the position of the first word of
+                                         the phrase in the sentence
+                      phrase_id - an identifier for this phrase
+                      phrase_type - a list of the coarse categories this
+                                    phrase belongs to
+
+    """
+    with open(filename, "r") as f:
+        sentences = f.read().split("\n")
+
+    annotations = []
+    for sentence in sentences:
+        if not sentence:
+            continue
+
+        first_word = []
+        phrases = []
+        phrase_id = []
+        phrase_type = []
+        words = []
+        current_phrase = []
+        add_to_phrase = False
+        for token in sentence.split():
+            if add_to_phrase:
+                if token[-1] == "]":
+                    add_to_phrase = False
+                    token = token[:-1]
+                    current_phrase.append(token)
+                    phrases.append(" ".join(current_phrase))
+                    current_phrase = []
+                else:
+                    current_phrase.append(token)
+
+                words.append(token)
+            else:
+                if token[0] == "[":
+                    add_to_phrase = True
+                    first_word.append(len(words))
+                    parts = token.split("/")
+                    phrase_id.append(parts[1][3:])
+                    phrase_type.append(parts[2:])
+                else:
+                    words.append(token)
+
+        sentence_data = {"sentence": " ".join(words), "phrases": []}
+        for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type):
+            sentence_data["phrases"].append(
+                {"first_word_index": index, "phrase": phrase, "phrase_id": p_id, "phrase_type": p_type}
+            )
+
+        annotations.append(sentence_data)
+
+    return annotations
+
+
+def get_annotations(filename) -> Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]]:
+    """
+    Parses the xml files in the Flickr30K Entities dataset
+
+    input:
+      filename - full file path to the annotations file to parse
+
+    output:
+      dictionary with the following fields:
+          scene - list of identifiers which were annotated as
+                  pertaining to the whole scene
+          nobox - list of identifiers which were annotated as
+                  not being visible in the image
+          boxes - a dictionary where the fields are identifiers
+                  and the values are its list of boxes in the
+                  [xmin ymin xmax ymax] format
+          height - int representing the height of the image
+          width - int representing the width of the image
+          depth - int representing the depth of the image
+    """
+    tree = ET.parse(filename)
+    root = tree.getroot()
+    size_container = root.findall("size")[0]
+    anno_info: Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]] = {}
+    all_boxes: Dict[str, List[List[int]]] = {}
+    all_noboxes: List[str] = []
+    all_scenes: List[str] = []
+    for size_element in size_container:
+        assert size_element.text
+        anno_info[size_element.tag] = int(size_element.text)
+
+    for object_container in root.findall("object"):
+        for names in object_container.findall("name"):
+            box_id = names.text
+            assert box_id
+            box_container = object_container.findall("bndbox")
+            if len(box_container) > 0:
+                if box_id not in all_boxes:
+                    all_boxes[box_id] = []
+                xmin = int(box_container[0].findall("xmin")[0].text)
+                ymin = int(box_container[0].findall("ymin")[0].text)
+                xmax = int(box_container[0].findall("xmax")[0].text)
+                ymax = int(box_container[0].findall("ymax")[0].text)
+                all_boxes[box_id].append([xmin, ymin, xmax, ymax])
+            else:
+                nobndbox = int(object_container.findall("nobndbox")[0].text)
+                if nobndbox > 0:
+                    all_noboxes.append(box_id)
+
+                scene = int(object_container.findall("scene")[0].text)
+                if scene > 0:
+                    all_scenes.append(box_id)
+    anno_info["boxes"] = all_boxes
+    anno_info["nobox"] = all_noboxes
+    anno_info["scene"] = all_scenes
+
+    return anno_info
+
+
+#### END of import from flickr30k_entities
+
+
+#### Bounding box utilities imported from torchvision and converted to numpy
+def box_area(boxes: np.array) -> np.array:
+    """
+    Computes the area of a set of bounding boxes, which are specified by its
+    (x1, y1, x2, y2) coordinates.
+
+    Args:
+        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
+            are expected to be in (x1, y1, x2, y2) format with
+            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+    Returns:
+        area (Tensor[N]): area for each box
+    """
+    assert boxes.ndim == 2 and boxes.shape[-1] == 4
+    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
+# with slight modifications
+def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]:
+    area1 = box_area(boxes1)
+    area2 = box_area(boxes2)
+
+    lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+    rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clip(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    union = area1[:, None] + area2 - inter
+
+    return inter, union
+
+
+def box_iou(boxes1: np.array, boxes2: np.array) -> np.array:
+    """
+    Return intersection-over-union (Jaccard index) of boxes.
+
+    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+    Args:
+        boxes1 (Tensor[N, 4])
+        boxes2 (Tensor[M, 4])
+
+    Returns:
+        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
+    """
+    inter, union = _box_inter_union(boxes1, boxes2)
+    iou = inter / union
+    return iou
+
+
+#### End of import of box utilities
+
+def _merge_boxes(boxes: List[List[int]]) -> List[List[int]]:
+    """
+    Return the boxes corresponding to the smallest enclosing box containing all the provided boxes
+    The boxes are expected in [x1, y1, x2, y2] format
+    """
+    if len(boxes) == 1:
+        return boxes
+
+    np_boxes = np.asarray(boxes)
+
+    return [[np_boxes[:, 0].min(), np_boxes[:, 1].min(), np_boxes[:, 2].max(), np_boxes[:, 3].max()]]
+
+
+class RecallTracker:
+    """ Utility class to track recall@k for various k, split by categories"""
+
+    def __init__(self, topk: Sequence[int]):
+        """
+        Parameters:
+           - topk : tuple of ints corresponding to the recalls being tracked (eg, recall@1, recall@10, ...)
+        """
+
+        self.total_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}
+        self.positives_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}
+
+    def add_positive(self, k: int, category: str):
+        """Log a positive hit @k for given category"""
+        if k not in self.total_byk_bycat:
+            raise RuntimeError(f"{k} is not a valid recall threshold")
+        self.total_byk_bycat[k][category] += 1
+        self.positives_byk_bycat[k][category] += 1
+
+    def add_negative(self, k: int, category: str):
+        """Log a negative hit @k for given category"""
+        if k not in self.total_byk_bycat:
+            raise RuntimeError(f"{k} is not a valid recall threshold")
+        self.total_byk_bycat[k][category] += 1
+
+    def report(self) -> Dict[int, Dict[str, float]]:
+        """Return a condensed report of the results as a dict of dict.
+        report[k][cat] is the recall@k for the given category
+        """
+        report: Dict[int, Dict[str, float]] = {}
+        for k in self.total_byk_bycat:
+            assert k in self.positives_byk_bycat
+            report[k] = {
+                cat: self.positives_byk_bycat[k][cat] / self.total_byk_bycat[k][cat] for cat in self.total_byk_bycat[k]
+            }
+        return report
+
+
+class Flickr30kEntitiesRecallEvaluator:
+    def __init__(
+            self,
+            flickr_path: str,
+            subset: str = "test",
+            topk: Sequence[int] = (1, 5, 10, -1),
+            iou_thresh: float = 0.5,
+            merge_boxes: bool = False,
+            verbose: bool = True,
+    ):
+        assert subset in ["train", "test", "val"], f"Wrong flickr subset {subset}"
+
+        self.topk = topk
+        self.iou_thresh = iou_thresh
+
+        flickr_path = Path(flickr_path)
+
+        # We load the image ids corresponding to the current subset
+        with open(flickr_path / f"{subset}.txt") as file_d:
+            self.img_ids = [line.strip() for line in file_d]
+
+        if verbose:
+            print(f"Flickr subset contains {len(self.img_ids)} images")
+
+        # Read the box annotations for all the images
+        self.imgid2boxes: Dict[str, Dict[str, List[List[int]]]] = {}
+
+        if verbose:
+            print("Loading annotations...")
+
+        for img_id in self.img_ids:
+            anno_info = get_annotations(flickr_path / "Annotations" / f"{img_id}.xml")["boxes"]
+            if merge_boxes:
+                merged = {}
+                for phrase_id, boxes in anno_info.items():
+                    merged[phrase_id] = _merge_boxes(boxes)
+                anno_info = merged
+            self.imgid2boxes[img_id] = anno_info
+
+        # Read the sentences annotations
+        self.imgid2sentences: Dict[str, List[List[Optional[Dict]]]] = {}
+
+        if verbose:
+            print("Loading annotations...")
+
+        self.all_ids: List[str] = []
+        tot_phrases = 0
+        for img_id in self.img_ids:
+            sentence_info = get_sentence_data(flickr_path / "Sentences" / f"{img_id}.txt")
+            self.imgid2sentences[img_id] = [None for _ in range(len(sentence_info))]
+
+            # Some phrases don't have boxes, we filter them.
+            for sent_id, sentence in enumerate(sentence_info):
+                phrases = [phrase for phrase in sentence["phrases"] if phrase["phrase_id"] in self.imgid2boxes[img_id]]
+                if len(phrases) > 0:
+                    self.imgid2sentences[img_id][sent_id] = phrases
+                tot_phrases += len(phrases)
+
+            self.all_ids += [
+                f"{img_id}_{k}" for k in range(len(sentence_info)) if self.imgid2sentences[img_id][k] is not None
+            ]
+
+        if verbose:
+            print(f"There are {tot_phrases} phrases in {len(self.all_ids)} sentences to evaluate")
+
+    def evaluate(self, predictions: List[Dict]):
+        evaluated_ids = set()
+
+        recall_tracker = RecallTracker(self.topk)
+
+        for pred in predictions:
+            cur_id = f"{pred['image_id']}_{pred['sentence_id']}"
+            if cur_id in evaluated_ids:
+                print(
+                    "Warning, multiple predictions found for sentence"
+                    f"{pred['sentence_id']} in image {pred['image_id']}"
+                )
+                continue
+
+            # Skip the sentences with no valid phrase
+            if cur_id not in self.all_ids:
+                if len(pred["boxes"]) != 0:
+                    print(
+                        f"Warning, in image {pred['image_id']} we were not expecting predictions "
+                        f"for sentence {pred['sentence_id']}. Ignoring them."
+                    )
+                continue
+
+            evaluated_ids.add(cur_id)
+
+            pred_boxes = pred["boxes"]
+            if str(pred["image_id"]) not in self.imgid2sentences:
+                raise RuntimeError(f"Unknown image id {pred['image_id']}")
+            if not 0 <= int(pred["sentence_id"]) < len(self.imgid2sentences[str(pred["image_id"])]):
+                raise RuntimeError(f"Unknown sentence id {pred['sentence_id']}" f" in image {pred['image_id']}")
+            target_sentence = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])]
+
+            phrases = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])]
+            if len(pred_boxes) != len(phrases):
+                raise RuntimeError(
+                    f"Error, got {len(pred_boxes)} predictions, expected {len(phrases)} "
+                    f"for sentence {pred['sentence_id']} in image {pred['image_id']}"
+                )
+
+            for cur_boxes, phrase in zip(pred_boxes, phrases):
+                target_boxes = self.imgid2boxes[str(pred["image_id"])][phrase["phrase_id"]]
+
+                ious = box_iou(np.asarray(cur_boxes), np.asarray(target_boxes))
+                for k in self.topk:
+                    maxi = 0
+                    if k == -1:
+                        maxi = ious.max()
+                    else:
+                        assert k > 0
+                        maxi = ious[:k].max()
+                    if maxi >= self.iou_thresh:
+                        recall_tracker.add_positive(k, "all")
+                        for phrase_type in phrase["phrase_type"]:
+                            recall_tracker.add_positive(k, phrase_type)
+                    else:
+                        recall_tracker.add_negative(k, "all")
+                        for phrase_type in phrase["phrase_type"]:
+                            recall_tracker.add_negative(k, phrase_type)
+
+        if len(evaluated_ids) != len(self.all_ids):
+            print("ERROR, the number of evaluated sentence doesn't match. Missing predictions:")
+            un_processed = set(self.all_ids) - evaluated_ids
+            for missing in un_processed:
+                img_id, sent_id = missing.split("_")
+                print(f"\t sentence {sent_id} in image {img_id}")
+            raise RuntimeError("Missing predictions")
+
+        return recall_tracker.report()
+
+
+class FlickrEvaluator(object):
+    def __init__(
+            self,
+            flickr_path,
+            subset,
+            top_k=(1, 5, 10, -1),
+            iou_thresh=0.5,
+            merge_boxes=False,
+    ):
+        assert isinstance(top_k, (list, tuple))
+
+        self.evaluator = Flickr30kEntitiesRecallEvaluator(
+            flickr_path, subset=subset, topk=top_k, iou_thresh=iou_thresh, merge_boxes=merge_boxes, verbose=False
+        )
+        self.predictions = []
+        self.results = None
+
+    def accumulate(self):
+        pass
+
+    def update(self, predictions):
+        self.predictions += predictions
+
+    def synchronize_between_processes(self):
+        all_predictions = dist.all_gather(self.predictions)
+        self.predictions = sum(all_predictions, [])
+
+    def summarize(self):
+        if dist.is_main_process():
+            self.results = self.evaluator.evaluate(self.predictions)
+            table = PrettyTable()
+            all_cat = sorted(list(self.results.values())[0].keys())
+            table.field_names = ["Recall@k"] + all_cat
+
+            score = {}
+            for k, v in self.results.items():
+                cur_results = [v[cat] for cat in all_cat]
+                header = "Upper_bound" if k == -1 else f"Recall@{k}"
+
+                for cat in all_cat:
+                    score[f"{header}_{cat}"] = v[cat]
+                table.add_row([header] + cur_results)
+
+            print(table)
+
+            return score
+
+        return None, None
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/lvis/_change_lvis_annotation.py b/maskrcnn_benchmark/data/datasets/evaluation/lvis/_change_lvis_annotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..11332d93c1b85fc3df3b6a2480cb1be0e610bac4
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/lvis/_change_lvis_annotation.py
@@ -0,0 +1,10 @@
+path = "DATASET/coco/annotations/lvis_v1_minival.json"
+import json
+with open(path) as f:
+    all = json.load(f)
+
+for i in all["images"]:
+    i["file_name"] = "/".join(i["coco_url"].split("/")[-2:])
+
+with open("DATASET/coco/annotations/lvis_v1_minival_inserted_image_name.json", "w") as f:
+    json.dump(all, f)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis.py b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9288c85562667527e6f41d97f3201c6b71a305
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis.py
@@ -0,0 +1,207 @@
+# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import json
+import os
+import time
+from collections import defaultdict
+
+import pycocotools.mask as mask_utils
+import torchvision
+from PIL import Image
+
+
+
+def _isArrayLike(obj):
+    return hasattr(obj, "__iter__") and hasattr(obj, "__len__")
+
+
+class LVIS:
+    def __init__(self, annotation_path=None):
+        """Class for reading and visualizing annotations.
+        Args:
+            annotation_path (str): location of annotation file
+        """
+        self.anns = {}
+        self.cats = {}
+        self.imgs = {}
+        self.img_ann_map = defaultdict(list)
+        self.cat_img_map = defaultdict(list)
+        self.dataset = {}
+
+        if annotation_path is not None:
+            print("Loading annotations.")
+
+            tic = time.time()
+            self.dataset = self._load_json(annotation_path)
+            print("Done (t={:0.2f}s)".format(time.time() - tic))
+
+            assert type(self.dataset) == dict, "Annotation file format {} not supported.".format(type(self.dataset))
+            self._create_index()
+
+    def _load_json(self, path):
+        with open(path, "r") as f:
+            return json.load(f)
+
+    def _create_index(self):
+        print("Creating index.")
+
+        self.img_ann_map = defaultdict(list)
+        self.cat_img_map = defaultdict(list)
+
+        self.anns = {}
+        self.cats = {}
+        self.imgs = {}
+
+        for ann in self.dataset["annotations"]:
+            self.img_ann_map[ann["image_id"]].append(ann)
+            self.anns[ann["id"]] = ann
+
+        for img in self.dataset["images"]:
+            self.imgs[img["id"]] = img
+
+        for cat in self.dataset["categories"]:
+            self.cats[cat["id"]] = cat
+
+        for ann in self.dataset["annotations"]:
+            self.cat_img_map[ann["category_id"]].append(ann["image_id"])
+
+        print("Index created.")
+
+    def get_ann_ids(self, img_ids=None, cat_ids=None, area_rng=None):
+        """Get ann ids that satisfy given filter conditions.
+        Args:
+            img_ids (int array): get anns for given imgs
+            cat_ids (int array): get anns for given cats
+            area_rng (float array): get anns for a given area range. e.g [0, inf]
+        Returns:
+            ids (int array): integer array of ann ids
+        """
+        if img_ids is not None:
+            img_ids = img_ids if _isArrayLike(img_ids) else [img_ids]
+        if cat_ids is not None:
+            cat_ids = cat_ids if _isArrayLike(cat_ids) else [cat_ids]
+        anns = []
+        if img_ids is not None:
+            for img_id in img_ids:
+                anns.extend(self.img_ann_map[img_id])
+        else:
+            anns = self.dataset["annotations"]
+
+        # return early if no more filtering required
+        if cat_ids is None and area_rng is None:
+            return [_ann["id"] for _ann in anns]
+
+        cat_ids = set(cat_ids)
+
+        if area_rng is None:
+            area_rng = [0, float("inf")]
+
+        ann_ids = [
+            _ann["id"]
+            for _ann in anns
+            if _ann["category_id"] in cat_ids and _ann["area"] > area_rng[0] and _ann["area"] < area_rng[1]
+        ]
+        return ann_ids
+
+    def get_cat_ids(self):
+        """Get all category ids.
+        Returns:
+            ids (int array): integer array of category ids
+        """
+        return list(self.cats.keys())
+
+    def get_img_ids(self):
+        """Get all img ids.
+        Returns:
+            ids (int array): integer array of image ids
+        """
+        return list(self.imgs.keys())
+
+    def _load_helper(self, _dict, ids):
+        if ids is None:
+            return list(_dict.values())
+        elif _isArrayLike(ids):
+            return [_dict[id] for id in ids]
+        else:
+            return [_dict[ids]]
+
+    def load_anns(self, ids=None):
+        """Load anns with the specified ids. If ids=None load all anns.
+        Args:
+            ids (int array): integer array of annotation ids
+        Returns:
+            anns (dict array) : loaded annotation objects
+        """
+        return self._load_helper(self.anns, ids)
+
+    def load_cats(self, ids):
+        """Load categories with the specified ids. If ids=None load all
+        categories.
+        Args:
+            ids (int array): integer array of category ids
+        Returns:
+            cats (dict array) : loaded category dicts
+        """
+        return self._load_helper(self.cats, ids)
+
+    def load_imgs(self, ids):
+        """Load categories with the specified ids. If ids=None load all images.
+        Args:
+            ids (int array): integer array of image ids
+        Returns:
+            imgs (dict array) : loaded image dicts
+        """
+        return self._load_helper(self.imgs, ids)
+
+    def download(self, save_dir, img_ids=None):
+        """Download images from mscoco.org server.
+        Args:
+            save_dir (str): dir to save downloaded images
+            img_ids (int array): img ids of images to download
+        """
+        imgs = self.load_imgs(img_ids)
+
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+
+        for img in imgs:
+            file_name = os.path.join(save_dir, img["file_name"])
+            if not os.path.exists(file_name):
+                from urllib.request import urlretrieve
+
+                urlretrieve(img["coco_url"], file_name)
+
+    def ann_to_rle(self, ann):
+        """Convert annotation which can be polygons, uncompressed RLE to RLE.
+        Args:
+            ann (dict) : annotation object
+        Returns:
+            ann (rle)
+        """
+        img_data = self.imgs[ann["image_id"]]
+        h, w = img_data["height"], img_data["width"]
+        segm = ann["segmentation"]
+        if isinstance(segm, list):
+            # polygon -- a single object might consist of multiple parts
+            # we merge all parts into one mask rle code
+            rles = mask_utils.frPyObjects(segm, h, w)
+            rle = mask_utils.merge(rles)
+        elif isinstance(segm["counts"], list):
+            # uncompressed RLE
+            rle = mask_utils.frPyObjects(segm, h, w)
+        else:
+            # rle
+            rle = ann["segmentation"]
+        return rle
+
+    def ann_to_mask(self, ann):
+        """Convert annotation which can be polygons, uncompressed RLE, or RLE
+        to binary mask.
+        Args:
+            ann (dict) : annotation object
+        Returns:
+            binary mask (numpy 2D array)
+        """
+        rle = self.ann_to_rle(ann)
+        return mask_utils.decode(rle)
+
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eeca5d2b4cb68bcda1dbd96ae25715ae2deb120
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis_eval.py
@@ -0,0 +1,998 @@
+# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import copy
+import datetime
+import json
+import os
+from collections import OrderedDict, defaultdict
+
+import numpy as np
+import pycocotools.mask as mask_util
+import torch
+import torch._six
+
+import maskrcnn_benchmark.utils.mdetr_dist  as dist
+
+from maskrcnn_benchmark.utils.mdetr_dist import all_gather
+
+
+from .lvis import LVIS
+
+def merge(img_ids, eval_imgs):
+    all_img_ids = all_gather(img_ids)
+    all_eval_imgs = all_gather(eval_imgs)
+
+    merged_img_ids = []
+    for p in all_img_ids:
+        merged_img_ids.extend(p)
+
+    merged_eval_imgs = []
+    for p in all_eval_imgs:
+        merged_eval_imgs.append(p)
+
+    merged_img_ids = np.array(merged_img_ids)
+    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+
+    # keep only unique (and in sorted order) images
+    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+    merged_eval_imgs = merged_eval_imgs[..., idx]
+
+    return merged_img_ids, merged_eval_imgs
+
+
+#################################################################
+# From LVIS, with following changes:
+#     * fixed LVISEval constructor to accept empty dt
+#     * Removed logger
+#     * LVIS results supports numpy inputs
+#################################################################
+
+
+class Params:
+    def __init__(self, iou_type):
+        """Params for LVIS evaluation API."""
+        self.img_ids = []
+        self.cat_ids = []
+        # np.arange causes trouble.  the data point on arange is slightly
+        # larger than the true value
+        self.iou_thrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
+        self.rec_thrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
+        self.max_dets = 300
+        self.area_rng = [
+            [0 ** 2, 1e5 ** 2],
+            [0 ** 2, 32 ** 2],
+            [32 ** 2, 96 ** 2],
+            [96 ** 2, 1e5 ** 2],
+        ]
+        self.area_rng_lbl = ["all", "small", "medium", "large"]
+        self.use_cats = 1
+        # We bin categories in three bins based how many images of the training
+        # set the category is present in.
+        # r: Rare    :  < 10
+        # c: Common  : >= 10 and < 100
+        # f: Frequent: >= 100
+        self.img_count_lbl = ["r", "c", "f"]
+        self.iou_type = iou_type
+
+
+class LVISResults(LVIS):
+    def __init__(self, lvis_gt, results, max_dets=300):
+        """Constructor for LVIS results.
+        Args:
+            lvis_gt (LVIS class instance, or str containing path of
+            annotation file)
+            results (str containing path of result file or a list of dicts)
+            max_dets (int):  max number of detections per image. The official
+            value of max_dets for LVIS is 300.
+        """
+        super(LVISResults, self).__init__()
+        assert isinstance(lvis_gt, LVIS)
+        self.dataset["images"] = [img for img in lvis_gt.dataset["images"]]
+
+        if isinstance(results, str):
+            result_anns = self._load_json(results)
+        elif type(results) == np.ndarray:
+            result_anns = self.loadNumpyAnnotations(results)
+        else:
+            result_anns = results
+
+        if max_dets >= 0:
+            result_anns = self.limit_dets_per_image(result_anns, max_dets)
+
+        if len(result_anns) > 0 and "bbox" in result_anns[0]:
+            self.dataset["categories"] = copy.deepcopy(lvis_gt.dataset["categories"])
+            for id, ann in enumerate(result_anns):
+                x1, y1, w, h = ann["bbox"]
+                x2 = x1 + w
+                y2 = y1 + h
+
+                if "segmentation" not in ann:
+                    ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
+
+                ann["area"] = w * h
+                ann["id"] = id + 1
+
+        elif len(result_anns) > 0 and "segmentation" in result_anns[0]:
+            self.dataset["categories"] = copy.deepcopy(lvis_gt.dataset["categories"])
+            for id, ann in enumerate(result_anns):
+                # Only support compressed RLE format as segmentation results
+                ann["area"] = mask_util.area(ann["segmentation"])
+
+                if "bbox" not in ann:
+                    ann["bbox"] = mask_util.toBbox(ann["segmentation"])
+
+                ann["id"] = id + 1
+
+        self.dataset["annotations"] = result_anns
+        self._create_index()
+
+        # #FIXME: disabling this check for now
+        # img_ids_in_result = [ann["image_id"] for ann in result_anns]
+
+        # assert set(img_ids_in_result) == (
+        #     set(img_ids_in_result) & set(self.get_img_ids())
+        # ), "Results do not correspond to current LVIS set."
+
+    def limit_dets_per_image(self, anns, max_dets):
+        img_ann = defaultdict(list)
+        for ann in anns:
+            img_ann[ann["image_id"]].append(ann)
+
+        for img_id, _anns in img_ann.items():
+            if len(_anns) <= max_dets:
+                continue
+            _anns = sorted(_anns, key=lambda ann: ann["score"], reverse=True)
+            img_ann[img_id] = _anns[:max_dets]
+
+        return [ann for anns in img_ann.values() for ann in anns]
+
+    def get_top_results(self, img_id, score_thrs):
+        ann_ids = self.get_ann_ids(img_ids=[img_id])
+        anns = self.load_anns(ann_ids)
+        return list(filter(lambda ann: ann["score"] > score_thrs, anns))
+
+
+class LVISEval:
+    def __init__(self, lvis_gt, lvis_dt=None, iou_type="segm"):
+        """Constructor for LVISEval.
+        Args:
+            lvis_gt (LVIS class instance, or str containing path of annotation file)
+            lvis_dt (LVISResult class instance, or str containing path of result file,
+            or list of dict)
+            iou_type (str): segm or bbox evaluation
+        """
+
+        if iou_type not in ["bbox", "segm"]:
+            raise ValueError("iou_type: {} is not supported.".format(iou_type))
+
+        if isinstance(lvis_gt, LVIS):
+            self.lvis_gt = lvis_gt
+        elif isinstance(lvis_gt, str):
+            self.lvis_gt = LVIS(lvis_gt)
+        else:
+            raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt))
+
+        if isinstance(lvis_dt, LVISResults):
+            self.lvis_dt = lvis_dt
+        elif isinstance(lvis_dt, (str, list)):
+            self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt)
+        elif lvis_dt is not None:
+            raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt))
+
+        # per-image per-category evaluation results
+        self.eval_imgs = defaultdict(list)
+        self.eval = {}  # accumulated evaluation results
+        self._gts = defaultdict(list)  # gt for evaluation
+        self._dts = defaultdict(list)  # dt for evaluation
+        self.params = Params(iou_type=iou_type)  # parameters
+        self.results = OrderedDict()
+        self.stats = []
+        self.ious = {}  # ious between all gts and dts
+
+        self.params.img_ids = sorted(self.lvis_gt.get_img_ids())
+        self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids())
+
+    def _to_mask(self, anns, lvis):
+        for ann in anns:
+            rle = lvis.ann_to_rle(ann)
+            ann["segmentation"] = rle
+
+    def _prepare(self):
+        """Prepare self._gts and self._dts for evaluation based on params."""
+
+        cat_ids = self.params.cat_ids if self.params.cat_ids else None
+
+        gts = self.lvis_gt.load_anns(self.lvis_gt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids))
+        dts = self.lvis_dt.load_anns(self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids))
+        # convert ground truth to mask if iou_type == 'segm'
+        if self.params.iou_type == "segm":
+            self._to_mask(gts, self.lvis_gt)
+            self._to_mask(dts, self.lvis_dt)
+
+        # set ignore flag
+        for gt in gts:
+            if "ignore" not in gt:
+                gt["ignore"] = 0
+
+        for gt in gts:
+            self._gts[gt["image_id"], gt["category_id"]].append(gt)
+
+        # For federated dataset evaluation we will filter out all dt for an
+        # image which belong to categories not present in gt and not present in
+        # the negative list for an image. In other words detector is not penalized
+        # for categories about which we don't have gt information about their
+        # presence or absence in an image.
+        img_data = self.lvis_gt.load_imgs(ids=self.params.img_ids)
+        # per image map of categories not present in image
+        img_nl = {d["id"]: d["neg_category_ids"] for d in img_data}
+        # per image list of categories present in image
+        img_pl = defaultdict(set)
+        for ann in gts:
+            img_pl[ann["image_id"]].add(ann["category_id"])
+        # per image map of categoires which have missing gt. For these
+        # categories we don't penalize the detector for flase positives.
+        self.img_nel = {d["id"]: d["not_exhaustive_category_ids"] for d in img_data}
+
+        for dt in dts:
+            img_id, cat_id = dt["image_id"], dt["category_id"]
+            if cat_id not in img_nl[img_id] and cat_id not in img_pl[img_id]:
+                continue
+            self._dts[img_id, cat_id].append(dt)
+
+        self.freq_groups = self._prepare_freq_group()
+
+    def _prepare_freq_group(self):
+        freq_groups = [[] for _ in self.params.img_count_lbl]
+        cat_data = self.lvis_gt.load_cats(self.params.cat_ids)
+        for idx, _cat_data in enumerate(cat_data):
+            frequency = _cat_data["frequency"]
+            freq_groups[self.params.img_count_lbl.index(frequency)].append(idx)
+        return freq_groups
+
+    def evaluate(self):
+        """
+        Run per image evaluation on given images and store results
+        (a list of dict) in self.eval_imgs.
+        """
+
+        self.params.img_ids = list(np.unique(self.params.img_ids))
+
+        if self.params.use_cats:
+            cat_ids = self.params.cat_ids
+        else:
+            cat_ids = [-1]
+
+        self._prepare()
+
+        self.ious = {
+            (img_id, cat_id): self.compute_iou(img_id, cat_id) for img_id in self.params.img_ids for cat_id in cat_ids
+        }
+
+        # loop through images, area range, max detection number
+        self.eval_imgs = [
+            self.evaluate_img(img_id, cat_id, area_rng)
+            for cat_id in cat_ids
+            for area_rng in self.params.area_rng
+            for img_id in self.params.img_ids
+        ]
+
+    def _get_gt_dt(self, img_id, cat_id):
+        """Create gt, dt which are list of anns/dets. If use_cats is true
+        only anns/dets corresponding to tuple (img_id, cat_id) will be
+        used. Else, all anns/dets in image are used and cat_id is not used.
+        """
+        if self.params.use_cats:
+            gt = self._gts[img_id, cat_id]
+            dt = self._dts[img_id, cat_id]
+        else:
+            gt = [_ann for _cat_id in self.params.cat_ids for _ann in self._gts[img_id, cat_id]]
+            dt = [_ann for _cat_id in self.params.cat_ids for _ann in self._dts[img_id, cat_id]]
+        return gt, dt
+
+    def compute_iou(self, img_id, cat_id):
+        gt, dt = self._get_gt_dt(img_id, cat_id)
+
+        if len(gt) == 0 and len(dt) == 0:
+            return []
+
+        # Sort detections in decreasing order of score.
+        idx = np.argsort([-d["score"] for d in dt], kind="mergesort")
+        dt = [dt[i] for i in idx]
+
+        iscrowd = [int(False)] * len(gt)
+
+        if self.params.iou_type == "segm":
+            ann_type = "segmentation"
+        elif self.params.iou_type == "bbox":
+            ann_type = "bbox"
+        else:
+            raise ValueError("Unknown iou_type for iou computation.")
+        gt = [g[ann_type] for g in gt]
+        dt = [d[ann_type] for d in dt]
+
+        # compute iou between each dt and gt region
+        # will return array of shape len(dt), len(gt)
+        ious = mask_util.iou(dt, gt, iscrowd)
+        return ious
+
+    def evaluate_img(self, img_id, cat_id, area_rng):
+        """Perform evaluation for single category and image."""
+        gt, dt = self._get_gt_dt(img_id, cat_id)
+
+        if len(gt) == 0 and len(dt) == 0:
+            return None
+
+        # Add another filed _ignore to only consider anns based on area range.
+        for g in gt:
+            if g["ignore"] or (g["area"] < area_rng[0] or g["area"] > area_rng[1]):
+                g["_ignore"] = 1
+            else:
+                g["_ignore"] = 0
+
+        # Sort gt ignore last
+        gt_idx = np.argsort([g["_ignore"] for g in gt], kind="mergesort")
+        gt = [gt[i] for i in gt_idx]
+
+        # Sort dt highest score first
+        dt_idx = np.argsort([-d["score"] for d in dt], kind="mergesort")
+        dt = [dt[i] for i in dt_idx]
+
+        # load computed ious
+        ious = self.ious[img_id, cat_id][:, gt_idx] if len(self.ious[img_id, cat_id]) > 0 else self.ious[img_id, cat_id]
+
+        num_thrs = len(self.params.iou_thrs)
+        num_gt = len(gt)
+        num_dt = len(dt)
+
+        # Array to store the "id" of the matched dt/gt
+        gt_m = np.zeros((num_thrs, num_gt))
+        dt_m = np.zeros((num_thrs, num_dt))
+
+        gt_ig = np.array([g["_ignore"] for g in gt])
+        dt_ig = np.zeros((num_thrs, num_dt))
+
+        for iou_thr_idx, iou_thr in enumerate(self.params.iou_thrs):
+            if len(ious) == 0:
+                break
+
+            for dt_idx, _dt in enumerate(dt):
+                iou = min([iou_thr, 1 - 1e-10])
+                # information about best match so far (m=-1 -> unmatched)
+                # store the gt_idx which matched for _dt
+                m = -1
+                for gt_idx, _ in enumerate(gt):
+                    # if this gt already matched continue
+                    if gt_m[iou_thr_idx, gt_idx] > 0:
+                        continue
+                    # if _dt matched to reg gt, and on ignore gt, stop
+                    if m > -1 and gt_ig[m] == 0 and gt_ig[gt_idx] == 1:
+                        break
+                    # continue to next gt unless better match made
+                    if ious[dt_idx, gt_idx] < iou:
+                        continue
+                    # if match successful and best so far, store appropriately
+                    iou = ious[dt_idx, gt_idx]
+                    m = gt_idx
+
+                # No match found for _dt, go to next _dt
+                if m == -1:
+                    continue
+
+                # if gt to ignore for some reason update dt_ig.
+                # Should not be used in evaluation.
+                dt_ig[iou_thr_idx, dt_idx] = gt_ig[m]
+                # _dt match found, update gt_m, and dt_m with "id"
+                dt_m[iou_thr_idx, dt_idx] = gt[m]["id"]
+                gt_m[iou_thr_idx, m] = _dt["id"]
+
+        # For LVIS we will ignore any unmatched detection if that category was
+        # not exhaustively annotated in gt.
+        dt_ig_mask = [
+            d["area"] < area_rng[0] or d["area"] > area_rng[1] or d["category_id"] in self.img_nel[d["image_id"]]
+            for d in dt
+        ]
+        dt_ig_mask = np.array(dt_ig_mask).reshape((1, num_dt))  # 1 X num_dt
+        dt_ig_mask = np.repeat(dt_ig_mask, num_thrs, 0)  # num_thrs X num_dt
+        # Based on dt_ig_mask ignore any unmatched detection by updating dt_ig
+        dt_ig = np.logical_or(dt_ig, np.logical_and(dt_m == 0, dt_ig_mask))
+        # store results for given image and category
+        return {
+            "image_id": img_id,
+            "category_id": cat_id,
+            "area_rng": area_rng,
+            "dt_ids": [d["id"] for d in dt],
+            "gt_ids": [g["id"] for g in gt],
+            "dt_matches": dt_m,
+            "gt_matches": gt_m,
+            "dt_scores": [d["score"] for d in dt],
+            "gt_ignore": gt_ig,
+            "dt_ignore": dt_ig,
+        }
+
+    def accumulate(self):
+        """Accumulate per image evaluation results and store the result in
+        self.eval.
+        """
+
+        if not self.eval_imgs:
+            print("Warning: Please run evaluate first.")
+
+        if self.params.use_cats:
+            cat_ids = self.params.cat_ids
+        else:
+            cat_ids = [-1]
+
+        num_thrs = len(self.params.iou_thrs)
+        num_recalls = len(self.params.rec_thrs)
+        num_cats = len(cat_ids)
+        num_area_rngs = len(self.params.area_rng)
+        num_imgs = len(self.params.img_ids)
+
+        # -1 for absent categories
+        precision = -np.ones((num_thrs, num_recalls, num_cats, num_area_rngs))
+        recall = -np.ones((num_thrs, num_cats, num_area_rngs))
+
+        # Initialize dt_pointers
+        dt_pointers = {}
+        for cat_idx in range(num_cats):
+            dt_pointers[cat_idx] = {}
+            for area_idx in range(num_area_rngs):
+                dt_pointers[cat_idx][area_idx] = {}
+
+        # Per category evaluation
+        for cat_idx in range(num_cats):
+            Nk = cat_idx * num_area_rngs * num_imgs
+            for area_idx in range(num_area_rngs):
+                Na = area_idx * num_imgs
+                E = [self.eval_imgs[Nk + Na + img_idx] for img_idx in range(num_imgs)]
+                # Remove elements which are None
+                E = [e for e in E if e is not None]
+                if len(E) == 0:
+                    continue
+
+                # Append all scores: shape (N,)
+                dt_scores = np.concatenate([e["dt_scores"] for e in E], axis=0)
+                dt_ids = np.concatenate([e["dt_ids"] for e in E], axis=0)
+
+                dt_idx = np.argsort(-dt_scores, kind="mergesort")
+                dt_scores = dt_scores[dt_idx]
+                dt_ids = dt_ids[dt_idx]
+
+                dt_m = np.concatenate([e["dt_matches"] for e in E], axis=1)[:, dt_idx]
+                dt_ig = np.concatenate([e["dt_ignore"] for e in E], axis=1)[:, dt_idx]
+
+                gt_ig = np.concatenate([e["gt_ignore"] for e in E])
+                # num gt anns to consider
+                num_gt = np.count_nonzero(gt_ig == 0)
+
+                if num_gt == 0:
+                    continue
+
+                tps = np.logical_and(dt_m, np.logical_not(dt_ig))
+                fps = np.logical_and(np.logical_not(dt_m), np.logical_not(dt_ig))
+
+                tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
+                fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
+
+                dt_pointers[cat_idx][area_idx] = {
+                    "dt_ids": dt_ids,
+                    "tps": tps,
+                    "fps": fps,
+                }
+
+                for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
+                    tp = np.array(tp)
+                    fp = np.array(fp)
+                    num_tp = len(tp)
+                    rc = tp / num_gt
+                    if num_tp:
+                        recall[iou_thr_idx, cat_idx, area_idx] = rc[-1]
+                    else:
+                        recall[iou_thr_idx, cat_idx, area_idx] = 0
+
+                    # np.spacing(1) ~= eps
+                    pr = tp / (fp + tp + np.spacing(1))
+                    pr = pr.tolist()
+
+                    # Replace each precision value with the maximum precision
+                    # value to the right of that recall level. This ensures
+                    # that the  calculated AP value will be less suspectable
+                    # to small variations in the ranking.
+                    for i in range(num_tp - 1, 0, -1):
+                        if pr[i] > pr[i - 1]:
+                            pr[i - 1] = pr[i]
+
+                    rec_thrs_insert_idx = np.searchsorted(rc, self.params.rec_thrs, side="left")
+
+                    pr_at_recall = [0.0] * num_recalls
+
+                    try:
+                        for _idx, pr_idx in enumerate(rec_thrs_insert_idx):
+                            pr_at_recall[_idx] = pr[pr_idx]
+                    except Exception:
+                        pass
+                    precision[iou_thr_idx, :, cat_idx, area_idx] = np.array(pr_at_recall)
+
+        self.eval = {
+            "params": self.params,
+            "counts": [num_thrs, num_recalls, num_cats, num_area_rngs],
+            "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            "precision": precision,
+            "recall": recall,
+            "dt_pointers": dt_pointers,
+        }
+
+    def _summarize(self, summary_type, iou_thr=None, area_rng="all", freq_group_idx=None):
+        aidx = [idx for idx, _area_rng in enumerate(self.params.area_rng_lbl) if _area_rng == area_rng]
+
+        if summary_type == "ap":
+            s = self.eval["precision"]
+            if iou_thr is not None:
+                tidx = np.where(iou_thr == self.params.iou_thrs)[0]
+                s = s[tidx]
+            if freq_group_idx is not None:
+                s = s[:, :, self.freq_groups[freq_group_idx], aidx]
+            else:
+                s = s[:, :, :, aidx]
+        else:
+            s = self.eval["recall"]
+            if iou_thr is not None:
+                tidx = np.where(iou_thr == self.params.iou_thrs)[0]
+                s = s[tidx]
+            s = s[:, :, aidx]
+
+        if len(s[s > -1]) == 0:
+            mean_s = -1
+        else:
+            mean_s = np.mean(s[s > -1])
+        return mean_s
+
+    def summarize(self):
+        """Compute and display summary metrics for evaluation results."""
+        if not self.eval:
+            raise RuntimeError("Please run accumulate() first.")
+
+        max_dets = self.params.max_dets
+
+        self.results["AP"] = self._summarize("ap")
+        self.results["AP50"] = self._summarize("ap", iou_thr=0.50)
+        self.results["AP75"] = self._summarize("ap", iou_thr=0.75)
+        self.results["APs"] = self._summarize("ap", area_rng="small")
+        self.results["APm"] = self._summarize("ap", area_rng="medium")
+        self.results["APl"] = self._summarize("ap", area_rng="large")
+        self.results["APr"] = self._summarize("ap", freq_group_idx=0)
+        self.results["APc"] = self._summarize("ap", freq_group_idx=1)
+        self.results["APf"] = self._summarize("ap", freq_group_idx=2)
+
+        self.stats = np.zeros((9,))
+        self.stats[0] = self._summarize("ap")
+        self.stats[1] = self._summarize("ap", iou_thr=0.50)
+        self.stats[2] = self._summarize("ap", iou_thr=0.75)
+        self.stats[3] = self._summarize("ap", area_rng="small")
+        self.stats[4] = self._summarize("ap", area_rng="medium")
+        self.stats[5] = self._summarize("ap", area_rng="large")
+        self.stats[6] = self._summarize("ap", freq_group_idx=0)
+        self.stats[7] = self._summarize("ap", freq_group_idx=1)
+        self.stats[8] = self._summarize("ap", freq_group_idx=2)
+
+        key = "AR@{}".format(max_dets)
+        self.results[key] = self._summarize("ar")
+
+        for area_rng in ["small", "medium", "large"]:
+            key = "AR{}@{}".format(area_rng[0], max_dets)
+            self.results[key] = self._summarize("ar", area_rng=area_rng)
+        _returned = self.print_results()
+        return _returned
+
+    def run(self):
+        """Wrapper function which calculates the results."""
+        self.evaluate()
+        self.accumulate()
+        self.summarize()
+
+    def print_results(self):
+        template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}"
+        out_strings = []
+        for key, value in self.results.items():
+            max_dets = self.params.max_dets
+            if "AP" in key:
+                title = "Average Precision"
+                _type = "(AP)"
+            else:
+                title = "Average Recall"
+                _type = "(AR)"
+
+            if len(key) > 2 and key[2].isdigit():
+                iou_thr = float(key[2:]) / 100
+                iou = "{:0.2f}".format(iou_thr)
+            else:
+                iou = "{:0.2f}:{:0.2f}".format(self.params.iou_thrs[0], self.params.iou_thrs[-1])
+
+            if len(key) > 2 and key[2] in ["r", "c", "f"]:
+                cat_group_name = key[2]
+            else:
+                cat_group_name = "all"
+
+            if len(key) > 2 and key[2] in ["s", "m", "l"]:
+                area_rng = key[2]
+            else:
+                area_rng = "all"
+
+            print(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value))
+            out_strings.append(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value))
+        return out_strings
+
+    def get_results(self):
+        if not self.results:
+            print("Warning: results is empty. Call run().")
+        return self.results
+
+
+#################################################################
+# end of straight copy from lvis, just fixing constructor
+#################################################################
+
+
+class LvisEvaluator(object):
+    def __init__(self, lvis_gt, iou_types):
+        assert isinstance(iou_types, (list, tuple))
+        # lvis_gt = copy.deepcopy(lvis_gt)
+        self.lvis_gt = lvis_gt
+
+        self.iou_types = iou_types
+        self.coco_eval = {}
+        for iou_type in iou_types:
+            self.coco_eval[iou_type] = LVISEval(lvis_gt, iou_type=iou_type)
+
+        self.img_ids = []
+        self.eval_imgs = {k: [] for k in iou_types}
+
+    def update(self, predictions):
+        img_ids = list(np.unique(list(predictions.keys())))
+        self.img_ids.extend(img_ids)
+
+        for iou_type in self.iou_types:
+            results = self.prepare(predictions, iou_type)
+            lvis_dt = LVISResults(self.lvis_gt, results)
+            lvis_eval = self.coco_eval[iou_type]
+
+            lvis_eval.lvis_dt = lvis_dt
+            lvis_eval.params.img_ids = list(img_ids)
+            lvis_eval.evaluate()
+            eval_imgs = lvis_eval.eval_imgs
+            eval_imgs = np.asarray(eval_imgs).reshape(
+                len(lvis_eval.params.cat_ids), len(lvis_eval.params.area_rng), len(lvis_eval.params.img_ids)
+            )
+
+            self.eval_imgs[iou_type].append(eval_imgs)
+
+    def synchronize_between_processes(self):
+        for iou_type in self.iou_types:
+            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
+            create_common_lvis_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+
+    def accumulate(self):
+        for lvis_eval in self.coco_eval.values():
+            lvis_eval.accumulate()
+
+    def summarize(self):
+        for iou_type, lvis_eval in self.coco_eval.items():
+            print("IoU metric: {}".format(iou_type))
+            lvis_eval.summarize()
+
+    def prepare(self, predictions, iou_type):
+        if iou_type == "bbox":
+            return self.prepare_for_lvis_detection(predictions)
+        elif iou_type == "segm":
+            return self.prepare_for_lvis_segmentation(predictions)
+        elif iou_type == "keypoints":
+            return self.prepare_for_lvis_keypoint(predictions)
+        else:
+            raise ValueError("Unknown iou type {}".format(iou_type))
+
+    def prepare_for_lvis_detection(self, predictions):
+        lvis_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            lvis_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "bbox": box,
+                        "score": scores[k],
+                    }
+                    for k, box in enumerate(boxes)
+                ]
+            )
+        return lvis_results
+
+    def prepare_for_lvis_segmentation(self, predictions):
+        lvis_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            scores = prediction["scores"]
+            labels = prediction["labels"]
+            masks = prediction["masks"]
+
+            masks = masks > 0.5
+
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            rles = [
+                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
+            ]
+            for rle in rles:
+                rle["counts"] = rle["counts"].decode("utf-8")
+
+            lvis_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "segmentation": rle,
+                        "score": scores[k],
+                    }
+                    for k, rle in enumerate(rles)
+                ]
+            )
+        return lvis_results
+
+
+def _merge_lists(listA, listB, maxN, key):
+    result = []
+    indA, indB = 0, 0
+    while (indA < len(listA) or indB < len(listB)) and len(result) < maxN:
+        if (indB < len(listB)) and (indA >= len(listA) or key(listA[indA]) < key(listB[indB])):
+            result.append(listB[indB])
+            indB += 1
+        else:
+            result.append(listA[indA])
+            indA += 1
+    return result
+
+
+# Adapted from https://github.com/achalddave/large-vocab-devil/blob/9aaddc15b00e6e0d370b16743233e40d973cd53f/scripts/evaluate_ap_fixed.py
+class LvisEvaluatorFixedAP(object):
+    def __init__(self, gt: LVIS, topk=10000, fixed_ap=True):
+
+        self.results = []
+        self.by_cat = {}
+        self.gt = gt
+        self.topk = topk
+        self.fixed_ap = fixed_ap
+
+    def update(self, predictions):
+        cur_results = self.prepare(predictions)
+        if self.fixed_ap:
+            by_cat = defaultdict(list)
+            for ann in cur_results:
+                by_cat[ann["category_id"]].append(ann)
+
+            for cat, cat_anns in by_cat.items():
+                if cat not in self.by_cat:
+                    self.by_cat[cat] = []
+
+                cur = sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk]
+                self.by_cat[cat] = _merge_lists(self.by_cat[cat], cur, self.topk, key=lambda x: x["score"])
+        else:
+            by_id = defaultdict(list)
+            for ann in cur_results:
+                by_id[ann["image_id"]].append(ann)
+
+            for id_anns in by_id.values():
+                self.results.extend(sorted(id_anns, key=lambda x: x["score"], reverse=True)[:300])
+
+    def synchronize_between_processes(self):
+        if self.fixed_ap:
+            all_cats = dist.all_gather(self.by_cat)
+            self.by_cat = defaultdict(list)
+            for cats in all_cats:
+                for cat, cat_anns in cats.items():
+                    self.by_cat[cat].extend(cat_anns)
+        else:
+            self.results = sum(dist.all_gather(self.results), [])
+
+    def prepare(self, predictions):
+        lvis_results = []
+        for original_id, prediction in predictions:
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            lvis_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "bbox": box,
+                        "score": scores[k],
+                    }
+                    for k, box in enumerate(boxes)
+                ]
+            )
+        return lvis_results
+
+    def summarize(self):
+        if not dist.is_main_process():
+            return
+
+        if self.fixed_ap:
+            return self._summarize_fixed()
+        else:
+            return self._summarize_standard()
+
+    def _summarize_standard(self):
+        results = LVISResults(self.gt, self.results)
+        lvis_eval = LVISEval(self.gt, results, iou_type="bbox")
+        lvis_eval.run()
+        lvis_eval.print_results()
+
+    def _summarize_fixed(self):
+        results = []
+
+        missing_dets_cats = set()
+        for cat, cat_anns in self.by_cat.items():
+            if len(cat_anns) < self.topk:
+                missing_dets_cats.add(cat)
+            results.extend(sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk])
+        if missing_dets_cats:
+            print(
+                f"\n===\n"
+                f"{len(missing_dets_cats)} classes had less than {self.topk} detections!\n"
+                f"Outputting {self.topk} detections for each class will improve AP further.\n"
+                f"If using detectron2, please use the lvdevil/infer_topk.py script to "
+                f"output a results file with {self.topk} detections for each class.\n"
+                f"==="
+            )
+
+        results = LVISResults(self.gt, results, max_dets=-1)
+        lvis_eval = LVISEval(self.gt, results, iou_type="bbox")
+        params = lvis_eval.params
+        params.max_dets = -1  # No limit on detections per image.
+        lvis_eval.run()
+        scores = lvis_eval.print_results()
+        metrics = {k: v for k, v in lvis_eval.results.items() if k.startswith("AP")}
+        print("copypaste: %s,%s", ",".join(map(str, metrics.keys())), "path")
+        return scores
+
+
+class LvisDumper(object):
+    def __init__(self, topk=10000, fixed_ap=True, out_path="lvis_eval"):
+
+        self.results = []
+        self.by_cat = {}
+        self.topk = topk
+        self.fixed_ap = fixed_ap
+        self.out_path = out_path
+        if dist.is_main_process():
+            if not os.path.exists(self.out_path):
+                os.mkdir(self.out_path)
+
+    def update(self, predictions):
+        cur_results = self.prepare(predictions)
+        if self.fixed_ap:
+            by_cat = defaultdict(list)
+            for ann in cur_results:
+                by_cat[ann["category_id"]].append(ann)
+
+            for cat, cat_anns in by_cat.items():
+                if cat not in self.by_cat:
+                    self.by_cat[cat] = []
+
+                cur = sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk]
+                self.by_cat[cat] = _merge_lists(self.by_cat[cat], cur, self.topk, key=lambda x: x["score"])
+        else:
+            by_id = defaultdict(list)
+            for ann in cur_results:
+                by_id[ann["image_id"]].append(ann)
+
+            for id_anns in by_id.values():
+                self.results.extend(sorted(id_anns, key=lambda x: x["score"], reverse=True)[:300])
+
+    def synchronize_between_processes(self):
+        if self.fixed_ap:
+            all_cats = dist.all_gather(self.by_cat)
+            self.by_cat = defaultdict(list)
+            for cats in all_cats:
+                for cat, cat_anns in cats.items():
+                    self.by_cat[cat].extend(cat_anns)
+        else:
+            self.results = sum(dist.all_gather(self.results), [])
+
+    def prepare(self, predictions):
+        lvis_results = []
+        for original_id, prediction in predictions:
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            lvis_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "bbox": box,
+                        "score": scores[k],
+                    }
+                    for k, box in enumerate(boxes)
+                ]
+            )
+        return lvis_results
+
+    def summarize(self):
+        if not dist.is_main_process():
+            return
+
+        if self.fixed_ap:
+            self._summarize_fixed()
+        else:
+            self._summarize_standard()
+
+    def _summarize_standard(self):
+        json_path = os.path.join(self.out_path, "results.json")
+        print("dumping to ", json_path)
+        with open(json_path, "w") as f:
+            json.dump(self.results, f)
+
+        print("dumped")
+
+    def _summarize_fixed(self):
+        results = []
+
+        missing_dets_cats = set()
+        for cat, cat_anns in self.by_cat.items():
+            if len(cat_anns) < self.topk:
+                missing_dets_cats.add(cat)
+            results.extend(sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk])
+        if missing_dets_cats:
+            print(
+                f"\n===\n"
+                f"{len(missing_dets_cats)} classes had less than {self.topk} detections!\n"
+                f"Outputting {self.topk} detections for each class will improve AP further.\n"
+                f"If using detectron2, please use the lvdevil/infer_topk.py script to "
+                f"output a results file with {self.topk} detections for each class.\n"
+                f"==="
+            )
+
+        json_path = os.path.join(self.out_path, "results.json")
+        print("dumping to ", json_path)
+        with open(json_path, "w") as f:
+            json.dump(results, f)
+
+        print("dumped")
+
+
+def convert_to_xywh(boxes):
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
+
+
+def create_common_lvis_eval(lvis_eval, img_ids, eval_imgs):
+    img_ids, eval_imgs = merge(img_ids, eval_imgs)
+    img_ids = list(img_ids)
+    eval_imgs = list(eval_imgs.flatten())
+
+    lvis_eval.eval_imgs = eval_imgs
+    lvis_eval.params.img_ids = img_ids
+
+def lvis_evaluation():
+    pass
\ No newline at end of file
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/od_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/od_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ea88105b4480c4398ad6ab0864bd291fdf47ff
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/__init__.py
@@ -0,0 +1,20 @@
+from .od_eval import do_od_evaluation
+
+
+def od_to_grounding_evaluation(
+        dataset,
+        predictions,
+        output_folder,
+        box_only=False,
+        iou_types=("bbox",),
+        expected_results=(),
+        expected_results_sigma_tol=4, ):
+    return do_od_evaluation(
+        dataset=dataset,
+        predictions=predictions,
+        box_only=box_only,
+        output_folder=output_folder,
+        iou_types=iou_types,
+        expected_results=expected_results,
+        expected_results_sigma_tol=expected_results_sigma_tol,
+    )
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/od_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/od_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..443e0f7a2c70a48fedb54c9902a93a3ded15fcd0
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/od_eval.py
@@ -0,0 +1,532 @@
+import logging
+import tempfile
+import os
+import torch
+import numpy as np
+import json
+
+from collections import OrderedDict
+from tqdm import tqdm
+
+from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+
+
+def do_od_evaluation(
+        dataset,
+        predictions,
+        box_only,
+        output_folder,
+        iou_types,
+        expected_results,
+        expected_results_sigma_tol,
+):
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+
+    if box_only:
+        logger.info("Evaluating bbox proposals")
+        if dataset.coco is None and output_folder:
+            json_results = prepare_for_tsv_detection(predictions, dataset)
+            with open(os.path.join(output_folder, "box_proposals.json"), "w") as f:
+                json.dump(json_results, f)
+            return None
+        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
+        res = COCOResults("box_proposal")
+        for limit in [100, 1000]:
+            for area, suffix in areas.items():
+                stats = evaluate_box_proposals(
+                    predictions, dataset, area=area, limit=limit
+                )
+                key = "AR{}@{:d}".format(suffix, limit)
+                res.results["box_proposal"][key] = stats["ar"].item()
+        logger.info(res)
+        check_expected_results(res, expected_results, expected_results_sigma_tol)
+        if output_folder:
+            torch.save(res, os.path.join(output_folder, "box_proposals.pth"))
+        return res, predictions
+    logger.info("Preparing results for COCO format")
+    coco_results = {}
+    if "bbox" in iou_types:
+        logger.info("Preparing bbox results")
+        if dataset.coco is None:
+            coco_results["bbox"] = prepare_for_tsv_detection(predictions, dataset)
+        else:
+            coco_results["bbox"] = prepare_for_coco_detection(predictions, dataset)
+    if "segm" in iou_types:
+        logger.info("Preparing segm results")
+        coco_results["segm"] = prepare_for_coco_segmentation(predictions, dataset)
+    if 'keypoints' in iou_types:
+        logger.info('Preparing keypoints results')
+        coco_results['keypoints'] = prepare_for_coco_keypoint(predictions, dataset)
+
+    results = COCOResults(*iou_types)
+    logger.info("Evaluating predictions")
+    for iou_type in iou_types:
+        with tempfile.NamedTemporaryFile() as f:
+            file_path = f.name
+            if output_folder:
+                file_path = os.path.join(output_folder, iou_type + ".json")
+            if dataset.coco:
+                res = evaluate_predictions_on_coco(
+                    dataset.coco, coco_results[iou_type], file_path, iou_type
+                )
+                results.update(res)
+            elif output_folder:
+                with open(file_path, "w") as f:
+                    json.dump(coco_results[iou_type], f)
+
+    logger.info(results)
+    check_expected_results(results, expected_results, expected_results_sigma_tol)
+    if output_folder:
+        torch.save(results, os.path.join(output_folder, "coco_results.pth"))
+    return results, coco_results
+
+
+def prepare_for_tsv_detection(predictions, dataset):
+    # assert isinstance(dataset, COCODataset)
+    proposal_results = []
+    image_list = []
+    for im_id, prediction in enumerate(predictions):
+        image_info = dataset.get_img_info(im_id)
+        if len(prediction) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_id = image_info["id"]
+        image_width = image_info["width"]
+        image_height = image_info["height"]
+        prediction = prediction.resize((image_width, image_height))
+        prediction = prediction.convert("xywh")
+
+        boxes = prediction.bbox.tolist()
+        scores = prediction.get_field("scores").tolist()
+        labels = prediction.get_field("labels").tolist()
+        if prediction.has_field("centers"):
+            centers = prediction.get_field("centers")
+        else:
+            centers = None
+
+        for k, box in enumerate(boxes):
+            proposal = {
+                "image_id": image_id,
+                "category_id": labels[k],
+                "bbox": box,
+                "score": scores[k],
+                "area": image_width * image_height,
+                "iscrowd": 0,
+            }
+            if centers is not None:
+                proposal.update(center=centers[k].tolist())
+            proposal_results.append(proposal)
+
+        image_list.append(image_info)
+
+        # categories = [{'supercategory': 'proposal', 'id': 0, 'name': 'proposal'}]
+    return dict(images=image_list, annotations=proposal_results)
+
+
+def prepare_for_coco_detection(predictions, dataset):
+    # assert isinstance(dataset, COCODataset)
+    coco_results = []
+    for image_id, prediction in enumerate(predictions):
+        original_id = dataset.id_to_img_map[image_id]
+        if len(prediction) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]["width"]
+        image_height = dataset.coco.imgs[original_id]["height"]
+        prediction = prediction.resize((image_width, image_height))
+        prediction = prediction.convert("xywh")
+
+        boxes = prediction.bbox.tolist()
+        scores = prediction.get_field("scores").tolist()
+        labels = prediction.get_field("labels").tolist()
+
+        for k, box in enumerate(boxes):
+            if labels[k] in dataset.contiguous_category_id_to_json_id:
+                coco_results.append(
+                    {
+                        "image_id": original_id,
+                        "category_id": dataset.contiguous_category_id_to_json_id[labels[k]],
+                        "bbox": box,
+                        "score": scores[k],
+                    })
+
+    return coco_results
+
+
+def prepare_for_coco_segmentation(predictions, dataset):
+    import pycocotools.mask as mask_util
+    import numpy as np
+
+    masker = Masker(threshold=0.5, padding=1)
+    # assert isinstance(dataset, COCODataset)
+    coco_results = []
+    for image_id, prediction in tqdm(enumerate(predictions)):
+        original_id = dataset.id_to_img_map[image_id]
+        if len(prediction) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]["width"]
+        image_height = dataset.coco.imgs[original_id]["height"]
+        prediction = prediction.resize((image_width, image_height))
+        masks = prediction.get_field("mask")
+        # t = time.time()
+        # Masker is necessary only if masks haven't been already resized.
+        if list(masks.shape[-2:]) != [image_height, image_width]:
+            masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
+            masks = masks[0]
+        # logger.info('Time mask: {}'.format(time.time() - t))
+        # prediction = prediction.convert('xywh')
+
+        # boxes = prediction.bbox.tolist()
+        scores = prediction.get_field("scores").tolist()
+        labels = prediction.get_field("labels").tolist()
+
+        # rles = prediction.get_field('mask')
+
+        rles = [
+            mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
+            for mask in masks
+        ]
+        for rle in rles:
+            rle["counts"] = rle["counts"].decode("utf-8")
+
+        mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
+
+        coco_results.extend(
+            [
+                {
+                    "image_id": original_id,
+                    "category_id": mapped_labels[k],
+                    "segmentation": rle,
+                    "score": scores[k],
+                }
+                for k, rle in enumerate(rles)
+            ]
+        )
+    return coco_results
+
+
+def prepare_for_coco_keypoint(predictions, dataset):
+    # assert isinstance(dataset, COCODataset)
+    coco_results = []
+    for image_id, prediction in enumerate(predictions):
+        original_id = dataset.id_to_img_map[image_id]
+        if len(prediction.bbox) == 0:
+            continue
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]['width']
+        image_height = dataset.coco.imgs[original_id]['height']
+        prediction = prediction.resize((image_width, image_height))
+        prediction = prediction.convert('xywh')
+
+        boxes = prediction.bbox.tolist()
+        scores = prediction.get_field('scores').tolist()
+        labels = prediction.get_field('labels').tolist()
+        keypoints = prediction.get_field('keypoints')
+        keypoints = keypoints.resize((image_width, image_height))
+        keypoints = keypoints.to_coco_format()
+
+        mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
+
+        coco_results.extend([{
+            'image_id': original_id,
+            'category_id': mapped_labels[k],
+            'keypoints': keypoint,
+            'score': scores[k]} for k, keypoint in enumerate(keypoints)])
+    return coco_results
+
+
+# inspired from Detectron
+def evaluate_box_proposals(
+        predictions, dataset, thresholds=None, area="all", limit=None
+):
+    """Evaluate detection proposal recall metrics. This function is a much
+    faster alternative to the official COCO API recall evaluation code. However,
+    it produces slightly different results.
+    """
+    # Record max overlap value for each gt box
+    # Return vector of overlap values
+    areas = {
+        "all": 0,
+        "small": 1,
+        "medium": 2,
+        "large": 3,
+        "96-128": 4,
+        "128-256": 5,
+        "256-512": 6,
+        "512-inf": 7,
+    }
+    area_ranges = [
+        [0 ** 2, 1e5 ** 2],  # all
+        [0 ** 2, 32 ** 2],  # small
+        [32 ** 2, 96 ** 2],  # medium
+        [96 ** 2, 1e5 ** 2],  # large
+        [96 ** 2, 128 ** 2],  # 96-128
+        [128 ** 2, 256 ** 2],  # 128-256
+        [256 ** 2, 512 ** 2],  # 256-512
+        [512 ** 2, 1e5 ** 2],
+    ]  # 512-inf
+    assert area in areas, "Unknown area range: {}".format(area)
+    area_range = area_ranges[areas[area]]
+    gt_overlaps = []
+    num_pos = 0
+
+    for image_id, prediction in enumerate(predictions):
+        original_id = dataset.id_to_img_map[image_id]
+
+        # TODO replace with get_img_info?
+        image_width = dataset.coco.imgs[original_id]["width"]
+        image_height = dataset.coco.imgs[original_id]["height"]
+        prediction = prediction.resize((image_width, image_height))
+
+        # sort predictions in descending order
+        # TODO maybe remove this and make it explicit in the documentation
+        if prediction.has_field("objectness"):
+            inds = prediction.get_field("objectness").sort(descending=True)[1]
+        else:
+            inds = prediction.get_field("scores").sort(descending=True)[1]
+        prediction = prediction[inds]
+
+        ann_ids = dataset.coco.getAnnIds(imgIds=original_id)
+        anno = dataset.coco.loadAnns(ann_ids)
+        gt_boxes = [obj["bbox"] for obj in anno if obj["iscrowd"] == 0]
+        gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4)  # guard against no boxes
+        gt_boxes = BoxList(gt_boxes, (image_width, image_height), mode="xywh").convert(
+            "xyxy"
+        )
+        gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
+
+        if len(gt_boxes) == 0:
+            continue
+
+        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
+        gt_boxes = gt_boxes[valid_gt_inds]
+
+        num_pos += len(gt_boxes)
+
+        if len(gt_boxes) == 0:
+            continue
+
+        if len(prediction) == 0:
+            continue
+
+        if limit is not None and len(prediction) > limit:
+            prediction = prediction[:limit]
+
+        overlaps = boxlist_iou(prediction, gt_boxes)
+
+        _gt_overlaps = torch.zeros(len(gt_boxes))
+        for j in range(min(len(prediction), len(gt_boxes))):
+            # find which proposal box maximally covers each gt box
+            # and get the iou amount of coverage for each gt box
+            max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+            # find which gt box is 'best' covered (i.e. 'best' = most iou)
+            gt_ovr, gt_ind = max_overlaps.max(dim=0)
+            assert gt_ovr >= 0
+            # find the proposal box that covers the best covered gt box
+            box_ind = argmax_overlaps[gt_ind]
+            # record the iou coverage of this gt box
+            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
+            assert _gt_overlaps[j] == gt_ovr
+            # mark the proposal box and the gt box as used
+            overlaps[box_ind, :] = -1
+            overlaps[:, gt_ind] = -1
+
+        # append recorded iou coverage level
+        gt_overlaps.append(_gt_overlaps)
+
+    if len(gt_overlaps) == 0:
+        return {
+            "ar": torch.zeros(1),
+            "recalls": torch.zeros(1),
+            "thresholds": thresholds,
+            "gt_overlaps": gt_overlaps,
+            "num_pos": num_pos,
+        }
+
+    gt_overlaps = torch.cat(gt_overlaps, dim=0)
+    gt_overlaps, _ = torch.sort(gt_overlaps)
+
+    if thresholds is None:
+        step = 0.05
+        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
+    recalls = torch.zeros_like(thresholds)
+    # compute recall for each iou threshold
+    for i, t in enumerate(thresholds):
+        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
+    # ar = 2 * np.trapz(recalls, thresholds)
+    ar = recalls.mean()
+    return {
+        "ar": ar,
+        "recalls": recalls,
+        "thresholds": thresholds,
+        "gt_overlaps": gt_overlaps,
+        "num_pos": num_pos,
+    }
+
+
+def evaluate_predictions_on_coco(
+        coco_gt, coco_results, json_result_file, iou_type="bbox"
+):
+    import json
+
+    with open(json_result_file, "w") as f:
+        json.dump(coco_results, f)
+
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+
+    coco_dt = coco_gt.loadRes(str(json_result_file)) if coco_results else COCO()
+
+    # coco_dt = coco_gt.loadRes(coco_results)
+    if iou_type == 'keypoints':
+        coco_gt = filter_valid_keypoints(coco_gt, coco_dt)
+    coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
+    coco_eval.evaluate()
+    coco_eval.accumulate()
+    coco_eval.summarize()
+    if iou_type == 'bbox':
+        summarize_per_category(coco_eval, json_result_file.replace('.json', '.csv'))
+    return coco_eval
+
+
+def summarize_per_category(coco_eval, csv_output=None):
+    '''
+    Compute and display summary metrics for evaluation results.
+    Note this functin can *only* be applied on the default parameter setting
+    '''
+
+    def _summarize(iouThr=None, areaRng='all', maxDets=100):
+        p = coco_eval.params
+        titleStr = 'Average Precision'
+        typeStr = '(AP)'
+        iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
+            if iouThr is None else '{:0.2f}'.format(iouThr)
+        result_str = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ], '. \
+            format(titleStr, typeStr, iouStr, areaRng, maxDets)
+
+        aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
+        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+
+        # dimension of precision: [TxRxKxAxM]
+        s = coco_eval.eval['precision']
+        # IoU
+        if iouThr is not None:
+            t = np.where(iouThr == p.iouThrs)[0]
+            s = s[t]
+        s = s[:, :, :, aind, mind]
+
+        if len(s[s > -1]) == 0:
+            mean_s = -1
+        else:
+            mean_s = np.mean(s[s > -1])
+            # cacluate AP(average precision) for each category
+            num_classes = len(p.catIds)
+            avg_ap = 0.0
+            for i in range(0, num_classes):
+                result_str += '{}, '.format(np.mean(s[:, :, i, :]))
+                avg_ap += np.mean(s[:, :, i, :])
+            result_str += ('{} \n'.format(avg_ap / num_classes))
+        return result_str
+
+    id2name = {}
+    for _, cat in coco_eval.cocoGt.cats.items():
+        id2name[cat['id']] = cat['name']
+    title_str = 'metric, '
+    for cid in coco_eval.params.catIds:
+        title_str += '{}, '.format(id2name[cid])
+    title_str += 'avg \n'
+
+    results = [title_str]
+    results.append(_summarize())
+    results.append(_summarize(iouThr=.5, maxDets=coco_eval.params.maxDets[2]))
+    results.append(_summarize(areaRng='small', maxDets=coco_eval.params.maxDets[2]))
+    results.append(_summarize(areaRng='medium', maxDets=coco_eval.params.maxDets[2]))
+    results.append(_summarize(areaRng='large', maxDets=coco_eval.params.maxDets[2]))
+
+    with open(csv_output, 'w') as f:
+        for result in results:
+            f.writelines(result)
+
+
+def filter_valid_keypoints(coco_gt, coco_dt):
+    kps = coco_dt.anns[1]['keypoints']
+    for id, ann in coco_gt.anns.items():
+        ann['keypoints'][2::3] = [a * b for a, b in zip(ann['keypoints'][2::3], kps[2::3])]
+        ann['num_keypoints'] = sum(ann['keypoints'][2::3])
+    return coco_gt
+
+
+class COCOResults(object):
+    METRICS = {
+        "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
+        "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
+        "box_proposal": [
+            "AR@100",
+            "ARs@100",
+            "ARm@100",
+            "ARl@100",
+            "AR@1000",
+            "ARs@1000",
+            "ARm@1000",
+            "ARl@1000",
+        ],
+        "keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
+    }
+
+    def __init__(self, *iou_types):
+        allowed_types = ("box_proposal", "bbox", "segm", "keypoints")
+        assert all(iou_type in allowed_types for iou_type in iou_types)
+        results = OrderedDict()
+        for iou_type in iou_types:
+            results[iou_type] = OrderedDict(
+                [(metric, -1) for metric in COCOResults.METRICS[iou_type]]
+            )
+        self.results = results
+
+    def update(self, coco_eval):
+        if coco_eval is None:
+            return
+        from pycocotools.cocoeval import COCOeval
+
+        assert isinstance(coco_eval, COCOeval)
+        s = coco_eval.stats
+        iou_type = coco_eval.params.iouType
+        res = self.results[iou_type]
+        metrics = COCOResults.METRICS[iou_type]
+        for idx, metric in enumerate(metrics):
+            res[metric] = s[idx]
+
+    def __repr__(self):
+        # TODO make it pretty
+        return repr(self.results)
+
+
+def check_expected_results(results, expected_results, sigma_tol):
+    if not expected_results:
+        return
+
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+    for task, metric, (mean, std) in expected_results:
+        actual_val = results.results[task][metric]
+        lo = mean - sigma_tol * std
+        hi = mean + sigma_tol * std
+        ok = (lo < actual_val) and (actual_val < hi)
+        msg = (
+            "{} > {} sanity check (actual vs. expected): "
+            "{:.3f} vs. mean={:.4f}, std={:.4}, range=({:.4f}, {:.4f})"
+        ).format(task, metric, actual_val, mean, std, lo, hi)
+        if not ok:
+            msg = "FAIL: " + msg
+            logger.error(msg)
+        else:
+            msg = "PASS: " + msg
+            logger.info(msg)
+
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/vg/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/vg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef18b3e5e9b007018fd7c839c7d053c48c2984d3
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/vg/__init__.py
@@ -0,0 +1,16 @@
+import logging
+
+from .vg_eval import do_vg_evaluation
+
+
+def vg_evaluation(dataset, predictions, output_folder, box_only, eval_attributes=False, **_):
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+    logger.info("performing vg evaluation, ignored iou_types.")
+    return do_vg_evaluation(
+        dataset=dataset,
+        predictions=predictions,
+        output_folder=output_folder,
+        box_only=box_only,
+        eval_attributes=eval_attributes,
+        logger=logger,
+    )
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/vg/vg_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/vg/vg_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb20fc9f69f1d70efa65eb9e88bab95d438f2b51
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/vg/vg_eval.py
@@ -0,0 +1,672 @@
+# A modification version from chainercv repository.
+# (See https://github.com/chainer/chainercv/blob/master/chainercv/evaluations/eval_detection_voc.py)
+from __future__ import division
+
+import os
+from collections import OrderedDict
+import numpy as np
+import torch
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou, getUnionBBox
+
+
+# inspired from Detectron
+def evaluate_box_proposals(
+    predictions, dataset, thresholds=None, area="all", limit=None
+):
+    """Evaluate detection proposal recall metrics. This function is a much
+    faster alternative to the official COCO API recall evaluation code. However,
+    it produces slightly different results.
+    """
+    # Record max overlap value for each gt box
+    # Return vector of overlap values
+    areas = {
+        "all": 0,
+        "small": 1,
+        "medium": 2,
+        "large": 3,
+        "96-128": 4,
+        "128-256": 5,
+        "256-512": 6,
+        "512-inf": 7,
+    }
+    area_ranges = [
+        [0 ** 2, 1e5 ** 2],  # all
+        [0 ** 2, 32 ** 2],  # small
+        [32 ** 2, 96 ** 2],  # medium
+        [96 ** 2, 1e5 ** 2],  # large
+        [96 ** 2, 128 ** 2],  # 96-128
+        [128 ** 2, 256 ** 2],  # 128-256
+        [256 ** 2, 512 ** 2],  # 256-512
+        [512 ** 2, 1e5 ** 2],
+    ]  # 512-inf
+    assert area in areas, "Unknown area range: {}".format(area)
+    area_range = area_ranges[areas[area]]
+    gt_overlaps = []
+    num_pos = 0
+
+    for image_id, prediction in enumerate(predictions):
+        img_info = dataset.get_img_info(image_id)
+        image_width = img_info["width"]
+        image_height = img_info["height"]
+        prediction = prediction.resize((image_width, image_height))
+
+        # deal with ground truth
+        gt_boxes = dataset.get_groundtruth(image_id)
+        # filter out the field "relations"
+        gt_boxes = gt_boxes.copy_with_fields(['attributes', 'labels'])
+        gt_areas = gt_boxes.area()
+
+        if len(gt_boxes) == 0:
+            continue
+
+        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
+        gt_boxes = gt_boxes[valid_gt_inds]
+
+        num_pos += len(gt_boxes)
+
+        if len(gt_boxes) == 0:
+            continue
+
+        # sort predictions in descending order
+        # TODO maybe remove this and make it explicit in the documentation
+        _gt_overlaps = torch.zeros(len(gt_boxes))
+        if len(prediction) == 0:
+            gt_overlaps.append(_gt_overlaps)
+            continue
+        if "objectness" in prediction.extra_fields:
+            inds = prediction.get_field("objectness").sort(descending=True)[1]
+        elif "scores" in prediction.extra_fields:
+            inds = prediction.get_field("scores").sort(descending=True)[1]
+        else:
+            raise ValueError("Neither objectness nor scores is in the extra_fields!")
+        prediction = prediction[inds]
+
+        if limit is not None and len(prediction) > limit:
+            prediction = prediction[:limit]
+
+        overlaps = boxlist_iou(prediction, gt_boxes)
+
+        for j in range(min(len(prediction), len(gt_boxes))):
+            # find which proposal box maximally covers each gt box
+            # and get the iou amount of coverage for each gt box
+            max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+            # find which gt box is 'best' covered (i.e. 'best' = most iou)
+            gt_ovr, gt_ind = max_overlaps.max(dim=0)
+            assert gt_ovr >= 0
+            # find the proposal box that covers the best covered gt box
+            box_ind = argmax_overlaps[gt_ind]
+            # record the iou coverage of this gt box
+            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
+            assert _gt_overlaps[j] == gt_ovr
+            # mark the proposal box and the gt box as used
+            overlaps[box_ind, :] = -1
+            overlaps[:, gt_ind] = -1
+
+        # append recorded iou coverage level
+        gt_overlaps.append(_gt_overlaps)
+    gt_overlaps = torch.cat(gt_overlaps, dim=0)
+    gt_overlaps, _ = torch.sort(gt_overlaps)
+
+    if thresholds is None:
+        step = 0.05
+        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
+    recalls = torch.zeros_like(thresholds)
+    # compute recall for each iou threshold
+    for i, t in enumerate(thresholds):
+        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
+    # ar = 2 * np.trapz(recalls, thresholds)
+    ar = recalls.mean()
+    return {
+        "ar": ar,
+        "recalls": recalls,
+        "thresholds": thresholds,
+        "gt_overlaps": gt_overlaps,
+        "num_pos": num_pos,
+    }
+
+
+class VGResults(object):
+    METRICS = {
+        "bbox": ["AP",],
+        "segm": ["AP",],
+        "box_proposal": ["AR@100",],
+    }
+
+    def __init__(self, iou_type, value):
+        allowed_types = ("box_proposal", "bbox", "segm", "keypoints")
+        assert iou_type in allowed_types
+        results = OrderedDict()
+        results[iou_type] = OrderedDict([(metric, value) for metric in VGResults.METRICS[iou_type]])
+        self.results = results
+
+
+def do_vg_evaluation(dataset, predictions, output_folder, box_only, eval_attributes, logger, save_predictions=True):
+    # TODO need to make the use_07_metric format available
+    # for the user to choose
+    # we use int for box_only. 0: False, 1: box for RPN, 2: box for object detection, 
+    if box_only:
+        if box_only == 1:
+            limits = [100, 1000]
+        elif box_only == 2:
+            limits = [36, 99]
+        else:
+            raise ValueError("box_only can be either 0/1/2, but get {0}".format(box_only))
+        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
+        result = {}
+        for area, suffix in areas.items():
+            for limit in limits:
+                logger.info("Evaluating bbox proposals@{:d}".format(limit))
+                stats = evaluate_box_proposals(
+                    predictions, dataset, area=area, limit=limit
+                )
+                key_ar = "AR{}@{:d}".format(suffix, limit)
+                key_num_pos = "num_pos{}@{:d}".format(suffix, limit)
+                result[key_num_pos] = stats["num_pos"]
+                result[key_ar] = stats["ar"].item()
+                key_recalls = "Recalls{}@{:d}".format(suffix, limit)
+                # result[key_recalls] = stats["recalls"]
+                print(key_recalls, stats["recalls"])
+                print(key_ar, "ar={:.4f}".format(result[key_ar]))
+                print(key_num_pos, "num_pos={:d}".format(result[key_num_pos]))
+                if limit != 1000 and dataset.relation_on:
+                    # if True:
+                    # relation @ 1000 (all and large) takes about 2 hs to compute
+                    # relation pair evaluation
+                    logger.info("Evaluating relation proposals@{:d}".format(limit))
+                    stats = evaluate_box_proposals_for_relation(
+                        predictions, dataset, area=area, limit=limit
+                    )
+                    key_ar = "AR{}@{:d}_for_relation".format(suffix, limit)
+                    key_num_pos = "num_pos{}@{:d}_for_relation".format(suffix, limit)
+                    result[key_num_pos] = stats["num_pos"]
+                    result[key_ar] = stats["ar"].item()
+                    # key_recalls = "Recalls{}@{:d}_for_relation".format(suffix, limit)
+                    # result[key_recalls] = stats["recalls"]
+                    print(key_ar, "ar={:.4f}".format(result[key_ar]))
+                    print(key_num_pos, "num_pos={:d}".format(result[key_num_pos]))
+        logger.info(result)
+        # check_expected_results(result, expected_results, expected_results_sigma_tol)
+        if output_folder and save_predictions:
+            if box_only == 1:
+                torch.save(result, os.path.join(output_folder, "rpn_proposals.pth"))
+            elif box_only == 2:
+                torch.save(result, os.path.join(output_folder, "box_proposals.pth"))
+            else:
+                raise ValueError("box_only can be either 0/1/2, but get {0}".format(box_only))
+        return VGResults('box_proposal', result["AR@100"]), {"box_proposal": result}
+
+    pred_boxlists = []
+    gt_boxlists = []
+    for image_id, prediction in enumerate(predictions):
+        img_info = dataset.get_img_info(image_id)
+        if len(prediction) == 0:
+            continue
+        image_width = img_info["width"]
+        image_height = img_info["height"]
+        prediction = prediction.resize((image_width, image_height))
+        pred_boxlists.append(prediction)
+
+        gt_boxlist = dataset.get_groundtruth(image_id)
+        gt_boxlists.append(gt_boxlist)
+    if eval_attributes:
+        classes = dataset.attributes
+    else:
+        classes = dataset.classes
+    result = eval_detection_voc(
+        pred_boxlists=pred_boxlists,
+        gt_boxlists=gt_boxlists,
+        classes=classes,
+        iou_thresh=0.5,
+        eval_attributes=eval_attributes,
+        use_07_metric=False,
+    )
+    result_str = "mAP: {:.4f}\n".format(result["map"])
+    logger.info(result_str)
+    for i, ap in enumerate(result["ap"]):
+        # if i == 0:  # skip background
+        #     continue
+        # we skipped background in result['ap'], so we need to use i+1
+        if eval_attributes:
+            result_str += "{:<16}: {:.4f}\n".format(
+                dataset.map_attribute_id_to_attribute_name(i+1), ap
+            )
+        else:
+            result_str += "{:<16}: {:.4f}\n".format(
+                dataset.map_class_id_to_class_name(i+1), ap
+            )
+    # return mAP and weighted mAP
+    vg_result = VGResults('bbox', result["map"])
+    if eval_attributes:
+        if output_folder and save_predictions:
+            with open(os.path.join(output_folder, "result_attr.txt"), "w") as fid:
+                fid.write(result_str)
+        return vg_result, {"attr": {"map": result["map"], "weighted map": result["weighted map"]}}
+    else:
+        if output_folder and save_predictions:
+            with open(os.path.join(output_folder, "result_obj.txt"), "w") as fid:
+                fid.write(result_str)
+        return vg_result, {"obj": {"map": result["map"], "weighted map": result["weighted map"]}},
+
+
+def eval_detection_voc(pred_boxlists, gt_boxlists, classes, iou_thresh=0.5, eval_attributes=False, use_07_metric=False):
+    """Evaluate on voc dataset.
+    Args:
+        pred_boxlists(list[BoxList]): pred boxlist, has labels and scores fields.
+        gt_boxlists(list[BoxList]): ground truth boxlist, has labels field.
+        iou_thresh: iou thresh
+        use_07_metric: boolean
+    Returns:
+        dict represents the results
+    """
+    assert len(gt_boxlists) == len(
+        pred_boxlists
+    ), "Length of gt and pred lists need to be same."
+
+    aps = []
+    nposs = []
+    thresh = []
+
+    for i, classname in enumerate(classes):
+        if classname == "__background__" or classname == "__no_attribute__":
+            continue
+        rec, prec, ap, scores, npos = calc_detection_voc_prec_rec(pred_boxlists=pred_boxlists, gt_boxlists=gt_boxlists, \
+                                                                  classindex=i, iou_thresh=iou_thresh,
+                                                                  eval_attributes=eval_attributes,
+                                                                  use_07_metric=use_07_metric)
+        # Determine per class detection thresholds that maximise f score
+        # if npos > 1:
+        if npos > 1 and type(scores) != np.int:
+            f = np.nan_to_num((prec * rec) / (prec + rec))
+            thresh += [scores[np.argmax(f)]]
+        else:
+            thresh += [0]
+        aps += [ap]
+        nposs += [float(npos)]
+        # print('AP for {} = {:.4f} (npos={:,})'.format(classname, ap, npos))
+        # if pickle:
+        #     with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
+        #         cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap, 
+        #             'scores': scores, 'npos':npos}, f)
+
+    # Set thresh to mean for classes with poor results 
+    thresh = np.array(thresh)
+    avg_thresh = np.mean(thresh[thresh != 0])
+    thresh[thresh == 0] = avg_thresh
+    # if eval_attributes:
+    #     filename = 'attribute_thresholds_' + self._image_set + '.txt'
+    # else:
+    #     filename = 'object_thresholds_' + self._image_set + '.txt'
+    # path = os.path.join(output_dir, filename)       
+    # with open(path, 'wt') as f:
+    #     for i, cls in enumerate(classes[1:]):
+    #         f.write('{:s} {:.3f}\n'.format(cls, thresh[i]))           
+
+    weights = np.array(nposs)
+    weights /= weights.sum()
+    # print('Mean AP = {:.4f}'.format(np.mean(aps)))
+    # print('Weighted Mean AP = {:.4f}'.format(np.average(aps, weights=weights)))
+    # print('Mean Detection Threshold = {:.3f}'.format(avg_thresh))
+    # print('~~~~~~~~')
+    # print('Results:')
+    # for ap, npos in zip(aps, nposs):
+    #     print('{:.3f}\t{:.3f}'.format(ap, npos))
+    # print('{:.3f}'.format(np.mean(aps)))
+    # print('~~~~~~~~')
+    # print('')
+    # print('--------------------------------------------------------------')
+    # print('Results computed with the **unofficial** PASCAL VOC Python eval code.')
+    # print('--------------------------------------------------------------')
+
+    # pdb.set_trace()
+    return {"ap": aps, "map": np.mean(aps), "weighted map": np.average(aps, weights=weights)}
+
+
+def calc_detection_voc_prec_rec(pred_boxlists, gt_boxlists, classindex, iou_thresh=0.5, eval_attributes=False,
+                                use_07_metric=False):
+    """Calculate precision and recall based on evaluation code of PASCAL VOC.
+    This function calculates precision and recall of
+    predicted bounding boxes obtained from a dataset which has :math:`N`
+    images.
+    The code is based on the evaluation code used in PASCAL VOC Challenge.
+   """
+    class_recs = {}
+    npos = 0
+    image_ids = []
+    confidence = []
+    BB = []
+    for image_index, (gt_boxlist, pred_boxlist) in enumerate(zip(gt_boxlists, pred_boxlists)):
+        pred_bbox = pred_boxlist.bbox.numpy()
+        gt_bbox = gt_boxlist.bbox.numpy()
+        if eval_attributes:
+            gt_label = gt_boxlist.get_field("attributes").numpy()
+            pred_label = pred_boxlist.get_field("attr_labels").numpy()
+            pred_score = pred_boxlist.get_field("attr_scores").numpy()
+        else:
+            gt_label = gt_boxlist.get_field("labels").numpy()
+            pred_label = pred_boxlist.get_field("labels").numpy()
+            pred_score = pred_boxlist.get_field("scores").numpy()
+
+        # get the ground truth information for this class
+        if eval_attributes:
+            gt_mask_l = np.array([classindex in i for i in gt_label])
+        else:
+            gt_mask_l = gt_label == classindex
+        gt_bbox_l = gt_bbox[gt_mask_l]
+        gt_difficult_l = np.zeros(gt_bbox_l.shape[0], dtype=bool)
+        det = [False] * gt_bbox_l.shape[0]
+        npos = npos + sum(~gt_difficult_l)
+        class_recs[image_index] = {'bbox': gt_bbox_l,
+                                   'difficult': gt_difficult_l,
+                                   'det': det}
+
+        # prediction output for each class
+        # pdb.set_trace()
+        if eval_attributes:
+            pred_mask_l = np.logical_and(pred_label == classindex, np.not_equal(pred_score, 0.0)).nonzero()
+            pred_bbox_l = pred_bbox[pred_mask_l[0]]
+            pred_score_l = pred_score[pred_mask_l]
+        else:
+            pred_mask_l = pred_label == classindex
+            pred_bbox_l = pred_bbox[pred_mask_l]
+            pred_score_l = pred_score[pred_mask_l]
+
+        for bbox_tmp, score_tmp in zip(pred_bbox_l, pred_score_l):
+            image_ids.append(image_index)
+            confidence.append(float(score_tmp))
+            BB.append([float(z) for z in bbox_tmp])
+
+    if npos == 0:
+        # No ground truth examples
+        return 0, 0, 0, 0, npos
+
+    if len(confidence) == 0:
+        # No detection examples
+        return 0, 0, 0, 0, npos
+
+    confidence = np.array(confidence)
+    BB = np.array(BB)
+
+    # sort by confidence
+    sorted_ind = np.argsort(-confidence)
+    sorted_scores = -np.sort(-confidence)
+    BB = BB[sorted_ind, :]
+    image_ids = [image_ids[x] for x in sorted_ind]
+
+    # go down dets and mark TPs and FPs
+    nd = len(image_ids)
+    tp = np.zeros(nd)
+    fp = np.zeros(nd)
+
+    for d in range(nd):
+        R = class_recs[image_ids[d]]
+        bb = BB[d, :].astype(float)
+        ovmax = -np.inf
+        BBGT = R['bbox'].astype(float)
+
+        if BBGT.size > 0:
+            # compute overlaps
+            # intersection
+            ixmin = np.maximum(BBGT[:, 0], bb[0])
+            iymin = np.maximum(BBGT[:, 1], bb[1])
+            ixmax = np.minimum(BBGT[:, 2], bb[2])
+            iymax = np.minimum(BBGT[:, 3], bb[3])
+            iw = np.maximum(ixmax - ixmin + 1., 0.)
+            ih = np.maximum(iymax - iymin + 1., 0.)
+            inters = iw * ih
+
+            # union
+            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
+                   (BBGT[:, 2] - BBGT[:, 0] + 1.) *
+                   (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
+
+            overlaps = inters / uni
+            ovmax = np.max(overlaps)
+            jmax = np.argmax(overlaps)
+
+        if ovmax > iou_thresh:
+            if not R['difficult'][jmax]:
+                if not R['det'][jmax]:
+                    tp[d] = 1.
+                    R['det'][jmax] = 1
+                else:
+                    fp[d] = 1.
+        else:
+            fp[d] = 1.
+
+    # compute precision recall
+    fp = np.cumsum(fp)
+    tp = np.cumsum(tp)
+    rec = tp / float(npos)
+    # avoid divide by zero in case the first detection matches a difficult
+    # ground truth
+    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+    ap = voc_ap(rec, prec, use_07_metric)
+
+    return rec, prec, ap, sorted_scores, npos
+
+
+def voc_ap(rec, prec, use_07_metric=False):
+    """ ap = voc_ap(rec, prec, [use_07_metric])
+    Compute VOC AP given precision and recall.
+    If use_07_metric is true, uses the
+    VOC 07 11 point method (default:False).
+    """
+    if use_07_metric:
+        # 11 point metric
+        ap = 0.
+        for t in np.arange(0., 1.1, 0.1):
+            if np.sum(rec >= t) == 0:
+                p = 0
+            else:
+                p = np.max(prec[rec >= t])
+            ap = ap + p / 11.
+    else:
+        # correct AP calculation
+        # first append sentinel values at the end
+        mrec = np.concatenate(([0.], rec, [1.]))
+        mpre = np.concatenate(([0.], prec, [0.]))
+
+        # compute the precision envelope
+        for i in range(mpre.size - 1, 0, -1):
+            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+        # to calculate area under PR curve, look for points
+        # where X axis (recall) changes value
+        i = np.where(mrec[1:] != mrec[:-1])[0]
+
+        # and sum (\Delta recall) * prec
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+    return ap
+
+
+def calc_detection_voc_ap(prec, rec, use_07_metric=False):
+    """Calculate average precisions based on evaluation code of PASCAL VOC.
+    This function calculates average precisions
+    from given precisions and recalls.
+    The code is based on the evaluation code used in PASCAL VOC Challenge.
+    Args:
+        prec (list of numpy.array): A list of arrays.
+            :obj:`prec[l]` indicates precision for class :math:`l`.
+            If :obj:`prec[l]` is :obj:`None`, this function returns
+            :obj:`numpy.nan` for class :math:`l`.
+        rec (list of numpy.array): A list of arrays.
+            :obj:`rec[l]` indicates recall for class :math:`l`.
+            If :obj:`rec[l]` is :obj:`None`, this function returns
+            :obj:`numpy.nan` for class :math:`l`.
+        use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric
+            for calculating average precision. The default value is
+            :obj:`False`.
+    Returns:
+        ~numpy.ndarray:
+        This function returns an array of average precisions.
+        The :math:`l`-th value corresponds to the average precision
+        for class :math:`l`. If :obj:`prec[l]` or :obj:`rec[l]` is
+        :obj:`None`, the corresponding value is set to :obj:`numpy.nan`.
+    """
+
+    n_fg_class = len(prec)
+    ap = np.empty(n_fg_class)
+    for l in range(n_fg_class):
+        if prec[l] is None or rec[l] is None:
+            ap[l] = np.nan
+            continue
+
+        if use_07_metric:
+            # 11 point metric
+            ap[l] = 0
+            for t in np.arange(0.0, 1.1, 0.1):
+                if np.sum(rec[l] >= t) == 0:
+                    p = 0
+                else:
+                    p = np.max(np.nan_to_num(prec[l])[rec[l] >= t])
+                ap[l] += p / 11
+        else:
+            # correct AP calculation
+            # first append sentinel values at the end
+            mpre = np.concatenate(([0], np.nan_to_num(prec[l]), [0]))
+            mrec = np.concatenate(([0], rec[l], [1]))
+
+            mpre = np.maximum.accumulate(mpre[::-1])[::-1]
+
+            # to calculate area under PR curve, look for points
+            # where X axis (recall) changes value
+            i = np.where(mrec[1:] != mrec[:-1])[0]
+
+            # and sum (\Delta recall) * prec
+            ap[l] = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+
+    return ap
+
+
+# inspired from Detectron
+def evaluate_box_proposals_for_relation(
+        predictions, dataset, thresholds=None, area="all", limit=None
+):
+    """Evaluate how many relation pairs can be captured by the proposed boxes.
+    """
+    # Record max overlap value for each gt box
+    # Return vector of overlap values
+    areas = {
+        "all": 0,
+        "small": 1,
+        "medium": 2,
+        "large": 3,
+        "96-128": 4,
+        "128-256": 5,
+        "256-512": 6,
+        "512-inf": 7,
+    }
+    area_ranges = [
+        [0 ** 2, 1e5 ** 2],  # all
+        [0 ** 2, 32 ** 2],  # small
+        [32 ** 2, 96 ** 2],  # medium
+        [96 ** 2, 1e5 ** 2],  # large
+        [96 ** 2, 128 ** 2],  # 96-128
+        [128 ** 2, 256 ** 2],  # 128-256
+        [256 ** 2, 512 ** 2],  # 256-512
+        [512 ** 2, 1e5 ** 2],
+    ]  # 512-inf
+    assert area in areas, "Unknown area range: {}".format(area)
+    area_range = area_ranges[areas[area]]
+    gt_overlaps = []
+    num_pos = 0
+
+    for image_id, prediction in enumerate(predictions):
+        img_info = dataset.get_img_info(image_id)
+        image_width = img_info["width"]
+        image_height = img_info["height"]
+        prediction = prediction.resize((image_width, image_height))
+
+        # deal with ground truth
+        gt_boxes = dataset.get_groundtruth(image_id)
+        # filter out the field "relation_labels"
+        gt_triplets = gt_boxes.get_field("relation_labels")
+        if len(gt_triplets) == 0:
+            continue
+        gt_boxes = gt_boxes.copy_with_fields(['attributes', 'labels'])
+        # get union bounding boxes (the box that cover both)
+        gt_relations = getUnionBBox(gt_boxes[gt_triplets[:, 0]], gt_boxes[gt_triplets[:, 1]], margin=0)
+        gt_relations.add_field('rel_classes', gt_triplets[:, 2])
+        # focus on the range interested
+        gt_relation_areas = gt_relations.area()
+        valid_gt_inds = (gt_relation_areas >= area_range[0]) & (gt_relation_areas <= area_range[1])
+        gt_relations = gt_relations[valid_gt_inds]
+
+        num_pos += len(gt_relations)
+
+        if len(gt_relations) == 0:
+            continue
+
+        # sort predictions in descending order and limit to the number we specify
+        # TODO maybe remove this and make it explicit in the documentation
+        _gt_overlaps = torch.zeros(len(gt_relations))
+        if len(prediction) == 0:
+            gt_overlaps.append(_gt_overlaps)
+            continue
+        if "objectness" in prediction.extra_fields:
+            inds = prediction.get_field("objectness").sort(descending=True)[1]
+        elif "scores" in prediction.extra_fields:
+            inds = prediction.get_field("scores").sort(descending=True)[1]
+        else:
+            raise ValueError("Neither objectness nor scores is in the extra_fields!")
+        prediction = prediction[inds]
+        if limit is not None and len(prediction) > limit:
+            prediction = prediction[:limit]
+        # get the predicted relation pairs
+        N = len(prediction)
+        map_x = np.arange(N)
+        map_y = np.arange(N)
+        map_x_g, map_y_g = np.meshgrid(map_x, map_y)
+        anchor_pairs = torch.from_numpy(np.vstack((map_y_g.ravel(), map_x_g.ravel())).transpose())
+        # remove diagonal pairs
+        keep = anchor_pairs[:, 0] != anchor_pairs[:, 1]
+        anchor_pairs = anchor_pairs[keep]
+        # get anchor_relations
+        # anchor_relations = getUnionBBox(prediction[anchor_pairs[:,0]], prediction[anchor_pairs[:,1]], margin=0)
+        if len(anchor_pairs) == 0:
+            continue
+
+        overlaps_sub = boxlist_iou(prediction[anchor_pairs[:, 0]], gt_boxes[gt_triplets[valid_gt_inds, 0]])
+        overlaps_obj = boxlist_iou(prediction[anchor_pairs[:, 1]], gt_boxes[gt_triplets[valid_gt_inds, 1]])
+        overlaps = torch.min(overlaps_sub, overlaps_obj)
+
+        for j in range(min(len(anchor_pairs), len(gt_relations))):
+            # find which proposal box maximally covers each gt box
+            # and get the iou amount of coverage for each gt box
+            max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+            # find which gt box is 'best' covered (i.e. 'best' = most iou)
+            gt_ovr, gt_ind = max_overlaps.max(dim=0)
+            assert gt_ovr >= 0
+            # find the proposal pair that covers the best covered gt pair
+            pair_ind = argmax_overlaps[gt_ind]
+            # record the co-iou coverage of this gt pair
+            _gt_overlaps[j] = overlaps[pair_ind, gt_ind]
+            assert _gt_overlaps[j] == gt_ovr
+            # mark the proposal pair and the gt pair as used
+            overlaps[pair_ind, :] = -1
+            overlaps[:, gt_ind] = -1
+
+        # append recorded iou coverage level
+        gt_overlaps.append(_gt_overlaps)
+    gt_overlaps = torch.cat(gt_overlaps, dim=0)
+    gt_overlaps, _ = torch.sort(gt_overlaps)
+
+    if thresholds is None:
+        step = 0.05
+        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
+    recalls = torch.zeros_like(thresholds)
+    # compute recall for each iou threshold
+    for i, t in enumerate(thresholds):
+        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
+    # ar = 2 * np.trapz(recalls, thresholds)
+    ar = recalls.mean()
+    return {
+        "ar": ar,
+        "recalls": recalls,
+        "thresholds": thresholds,
+        "gt_overlaps": gt_overlaps,
+        "num_pos": num_pos,
+    }
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/voc/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/voc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c26048b361ddd41b6e82d4bb9d5ead745f6bb07
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/voc/__init__.py
@@ -0,0 +1,16 @@
+import logging
+
+from .voc_eval import do_voc_evaluation
+
+
+def voc_evaluation(dataset, predictions, output_folder, box_only, **_):
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+    if box_only:
+        logger.warning("voc evaluation doesn't support box_only, ignored.")
+    logger.info("performing voc evaluation, ignored iou_types.")
+    return do_voc_evaluation(
+        dataset=dataset,
+        predictions=predictions,
+        output_folder=output_folder,
+        logger=logger,
+    )
diff --git a/maskrcnn_benchmark/data/datasets/evaluation/voc/voc_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/voc/voc_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac54d768d458861ca994dc7de1fa37f166ed012b
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/evaluation/voc/voc_eval.py
@@ -0,0 +1,216 @@
+# A modification version from chainercv repository.
+# (See https://github.com/chainer/chainercv/blob/master/chainercv/evaluations/eval_detection_voc.py)
+from __future__ import division
+
+import os
+from collections import defaultdict
+import numpy as np
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+
+
+def do_voc_evaluation(dataset, predictions, output_folder, logger):
+    # TODO need to make the use_07_metric format available
+    # for the user to choose
+    pred_boxlists = []
+    gt_boxlists = []
+    for image_id, prediction in enumerate(predictions):
+        img_info = dataset.get_img_info(image_id)
+        if len(prediction) == 0:
+            continue
+        image_width = img_info["width"]
+        image_height = img_info["height"]
+        prediction = prediction.resize((image_width, image_height))
+        pred_boxlists.append(prediction)
+
+        gt_boxlist = dataset.get_groundtruth(image_id)
+        gt_boxlists.append(gt_boxlist)
+    result = eval_detection_voc(
+        pred_boxlists=pred_boxlists,
+        gt_boxlists=gt_boxlists,
+        iou_thresh=0.5,
+        use_07_metric=True,
+    )
+    result_str = "mAP: {:.4f}\n".format(result["map"])
+    for i, ap in enumerate(result["ap"]):
+        if i == 0:  # skip background
+            continue
+        result_str += "{:<16}: {:.4f}\n".format(
+            dataset.map_class_id_to_class_name(i), ap
+        )
+    logger.info(result_str)
+    if output_folder:
+        with open(os.path.join(output_folder, "result.txt"), "w") as fid:
+            fid.write(result_str)
+    return result
+
+
+def eval_detection_voc(pred_boxlists, gt_boxlists, iou_thresh=0.5, use_07_metric=False):
+    """Evaluate on voc dataset.
+    Args:
+        pred_boxlists(list[BoxList]): pred boxlist, has labels and scores fields.
+        gt_boxlists(list[BoxList]): ground truth boxlist, has labels field.
+        iou_thresh: iou thresh
+        use_07_metric: boolean
+    Returns:
+        dict represents the results
+    """
+    assert len(gt_boxlists) == len(
+        pred_boxlists
+    ), "Length of gt and pred lists need to be same."
+    prec, rec = calc_detection_voc_prec_rec(
+        pred_boxlists=pred_boxlists, gt_boxlists=gt_boxlists, iou_thresh=iou_thresh
+    )
+    ap = calc_detection_voc_ap(prec, rec, use_07_metric=use_07_metric)
+    return {"ap": ap, "map": np.nanmean(ap)}
+
+
+def calc_detection_voc_prec_rec(gt_boxlists, pred_boxlists, iou_thresh=0.5):
+    """Calculate precision and recall based on evaluation code of PASCAL VOC.
+    This function calculates precision and recall of
+    predicted bounding boxes obtained from a dataset which has :math:`N`
+    images.
+    The code is based on the evaluation code used in PASCAL VOC Challenge.
+   """
+    n_pos = defaultdict(int)
+    score = defaultdict(list)
+    match = defaultdict(list)
+    for gt_boxlist, pred_boxlist in zip(gt_boxlists, pred_boxlists):
+        pred_bbox = pred_boxlist.bbox.numpy()
+        pred_label = pred_boxlist.get_field("labels").numpy()
+        pred_score = pred_boxlist.get_field("scores").numpy()
+        gt_bbox = gt_boxlist.bbox.numpy()
+        gt_label = gt_boxlist.get_field("labels").numpy()
+        gt_difficult = gt_boxlist.get_field("difficult").numpy()
+
+        for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)):
+            pred_mask_l = pred_label == l
+            pred_bbox_l = pred_bbox[pred_mask_l]
+            pred_score_l = pred_score[pred_mask_l]
+            # sort by score
+            order = pred_score_l.argsort()[::-1]
+            pred_bbox_l = pred_bbox_l[order]
+            pred_score_l = pred_score_l[order]
+
+            gt_mask_l = gt_label == l
+            gt_bbox_l = gt_bbox[gt_mask_l]
+            gt_difficult_l = gt_difficult[gt_mask_l]
+
+            n_pos[l] += np.logical_not(gt_difficult_l).sum()
+            score[l].extend(pred_score_l)
+
+            if len(pred_bbox_l) == 0:
+                continue
+            if len(gt_bbox_l) == 0:
+                match[l].extend((0,) * pred_bbox_l.shape[0])
+                continue
+
+            # VOC evaluation follows integer typed bounding boxes.
+            pred_bbox_l = pred_bbox_l.copy()
+            pred_bbox_l[:, 2:] += 1
+            gt_bbox_l = gt_bbox_l.copy()
+            gt_bbox_l[:, 2:] += 1
+            iou = boxlist_iou(
+                BoxList(pred_bbox_l, gt_boxlist.size),
+                BoxList(gt_bbox_l, gt_boxlist.size),
+            ).numpy()
+            gt_index = iou.argmax(axis=1)
+            # set -1 if there is no matching ground truth
+            gt_index[iou.max(axis=1) < iou_thresh] = -1
+            del iou
+
+            selec = np.zeros(gt_bbox_l.shape[0], dtype=bool)
+            for gt_idx in gt_index:
+                if gt_idx >= 0:
+                    if gt_difficult_l[gt_idx]:
+                        match[l].append(-1)
+                    else:
+                        if not selec[gt_idx]:
+                            match[l].append(1)
+                        else:
+                            match[l].append(0)
+                    selec[gt_idx] = True
+                else:
+                    match[l].append(0)
+
+    n_fg_class = max(n_pos.keys()) + 1
+    prec = [None] * n_fg_class
+    rec = [None] * n_fg_class
+
+    for l in n_pos.keys():
+        score_l = np.array(score[l])
+        match_l = np.array(match[l], dtype=np.int8)
+
+        order = score_l.argsort()[::-1]
+        match_l = match_l[order]
+
+        tp = np.cumsum(match_l == 1)
+        fp = np.cumsum(match_l == 0)
+
+        # If an element of fp + tp is 0,
+        # the corresponding element of prec[l] is nan.
+        prec[l] = tp / (fp + tp)
+        # If n_pos[l] is 0, rec[l] is None.
+        if n_pos[l] > 0:
+            rec[l] = tp / n_pos[l]
+
+    return prec, rec
+
+
+def calc_detection_voc_ap(prec, rec, use_07_metric=False):
+    """Calculate average precisions based on evaluation code of PASCAL VOC.
+    This function calculates average precisions
+    from given precisions and recalls.
+    The code is based on the evaluation code used in PASCAL VOC Challenge.
+    Args:
+        prec (list of numpy.array): A list of arrays.
+            :obj:`prec[l]` indicates precision for class :math:`l`.
+            If :obj:`prec[l]` is :obj:`None`, this function returns
+            :obj:`numpy.nan` for class :math:`l`.
+        rec (list of numpy.array): A list of arrays.
+            :obj:`rec[l]` indicates recall for class :math:`l`.
+            If :obj:`rec[l]` is :obj:`None`, this function returns
+            :obj:`numpy.nan` for class :math:`l`.
+        use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric
+            for calculating average precision. The default value is
+            :obj:`False`.
+    Returns:
+        ~numpy.ndarray:
+        This function returns an array of average precisions.
+        The :math:`l`-th value corresponds to the average precision
+        for class :math:`l`. If :obj:`prec[l]` or :obj:`rec[l]` is
+        :obj:`None`, the corresponding value is set to :obj:`numpy.nan`.
+    """
+
+    n_fg_class = len(prec)
+    ap = np.empty(n_fg_class)
+    for l in range(n_fg_class):
+        if prec[l] is None or rec[l] is None:
+            ap[l] = np.nan
+            continue
+
+        if use_07_metric:
+            # 11 point metric
+            ap[l] = 0
+            for t in np.arange(0.0, 1.1, 0.1):
+                if np.sum(rec[l] >= t) == 0:
+                    p = 0
+                else:
+                    p = np.max(np.nan_to_num(prec[l])[rec[l] >= t])
+                ap[l] += p / 11
+        else:
+            # correct AP calculation
+            # first append sentinel values at the end
+            mpre = np.concatenate(([0], np.nan_to_num(prec[l]), [0]))
+            mrec = np.concatenate(([0], rec[l], [1]))
+
+            mpre = np.maximum.accumulate(mpre[::-1])[::-1]
+
+            # to calculate area under PR curve, look for points
+            # where X axis (recall) changes value
+            i = np.where(mrec[1:] != mrec[:-1])[0]
+
+            # and sum (\Delta recall) * prec
+            ap[l] = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+
+    return ap
diff --git a/maskrcnn_benchmark/data/datasets/flickr.py b/maskrcnn_benchmark/data/datasets/flickr.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe71a932182f0cb88385e990c7f0c22342ef5fbf
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/flickr.py
@@ -0,0 +1,8 @@
+import torch
+import torchvision
+import torch.utils.data as data
+from maskrcnn_benchmark.data.datasets.modulated_coco import ModulatedDataset
+
+
+class FlickrDataset(ModulatedDataset):
+    pass
diff --git a/maskrcnn_benchmark/data/datasets/gqa.py b/maskrcnn_benchmark/data/datasets/gqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..98d906cf9c9cb7e4d5d2ad17923398b25f11d9f6
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/gqa.py
@@ -0,0 +1,91 @@
+import json
+from pathlib import Path
+
+import torch
+import torchvision
+
+from .modulated_coco import ConvertCocoPolysToMask, ModulatedDataset
+
+
+class GQADataset(ModulatedDataset):
+    pass
+
+
+class GQAQuestionAnswering(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms, return_masks, return_tokens, tokenizer, ann_folder):
+        super(GQAQuestionAnswering, self).__init__(img_folder, ann_file)
+        self._transforms = transforms
+        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer)
+        with open(ann_folder / "gqa_answer2id.json", "r") as f:
+            self.answer2id = json.load(f)
+        with open(ann_folder / "gqa_answer2id_by_type.json", "r") as f:
+            self.answer2id_by_type = json.load(f)
+        self.type2id = {"obj": 0, "attr": 1, "rel": 2, "global": 3, "cat": 4}
+
+    def __getitem__(self, idx):
+        img, target = super(GQAQuestionAnswering, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        coco_img = self.coco.loadImgs(image_id)[0]
+        caption = coco_img["caption"]
+        dataset_name = coco_img["dataset_name"]
+        questionId = coco_img["questionId"]
+        target = {"image_id": image_id, "annotations": target, "caption": caption}
+        img, target = self.prepare(img, target)
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        target["dataset_name"] = dataset_name
+        target["questionId"] = questionId
+
+        if coco_img["answer"] not in self.answer2id:
+            answer = "unknown"
+        else:
+            answer = coco_img["answer"]
+
+        target["answer"] = torch.as_tensor(self.answer2id[answer], dtype=torch.long)
+        target["answer_type"] = torch.as_tensor(self.type2id[coco_img["question_type"]], dtype=torch.long)
+
+        if coco_img["answer"] not in self.answer2id_by_type["answer_attr"]:
+            answer = "unknown"
+        else:
+            answer = coco_img["answer"]
+        target["answer_attr"] = torch.as_tensor(
+            self.answer2id_by_type["answer_attr"][answer] if coco_img["question_type"] == "attr" else -100,
+            dtype=torch.long,
+        )
+
+        if coco_img["answer"] not in self.answer2id_by_type["answer_global"]:
+            answer = "unknown"
+        else:
+            answer = coco_img["answer"]
+        target["answer_global"] = torch.as_tensor(
+            self.answer2id_by_type["answer_global"][answer] if coco_img["question_type"] == "global" else -100,
+            dtype=torch.long,
+        )
+
+        if coco_img["answer"] not in self.answer2id_by_type["answer_rel"]:
+            answer = "unknown"
+        else:
+            answer = coco_img["answer"]
+        target["answer_rel"] = torch.as_tensor(
+            self.answer2id_by_type["answer_rel"][answer] if coco_img["question_type"] == "rel" else -100,
+            dtype=torch.long,
+        )
+
+        if coco_img["answer"] not in self.answer2id_by_type["answer_cat"]:
+            answer = "unknown"
+        else:
+            answer = coco_img["answer"]
+        target["answer_cat"] = torch.as_tensor(
+            self.answer2id_by_type["answer_cat"][answer] if coco_img["question_type"] == "cat" else -100,
+            dtype=torch.long,
+        )
+
+        if coco_img["answer"] not in self.answer2id_by_type["answer_obj"]:
+            answer = "unknown"
+        else:
+            answer = coco_img["answer"]
+        target["answer_obj"] = torch.as_tensor(
+            self.answer2id_by_type["answer_obj"][answer] if coco_img["question_type"] == "obj" else -100,
+            dtype=torch.long,
+        )
+        return img, target
diff --git a/maskrcnn_benchmark/data/datasets/imagenet.py b/maskrcnn_benchmark/data/datasets/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..723ea7dcc89fc3cb2bc68664e3ede90a0083b3b3
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/imagenet.py
@@ -0,0 +1,63 @@
+import os
+import os.path
+import json
+from PIL import Image
+
+import torch.utils.data as data
+
+def pil_loader(path):
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, 'rb') as f:
+        img = Image.open(f)
+        return img.convert('RGB')
+
+class ImageNet(data.Dataset):
+    """ ImageNet
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+    """
+
+    def __init__(self, ann_file, root, remove_images_without_annotations=None, transforms=None):
+
+
+        self.root = root
+        self.transform = transforms
+
+        meta_file = os.path.join(root, ann_file)
+        assert os.path.exists(meta_file), 'meta file %s under root %s not found' % (os.path.basename(meta_file), root)
+
+        with open(meta_file, 'r') as f:
+            meta = json.load(f)
+
+        self.classes = meta['classes']
+        self.class_to_idx = meta['class_to_idx']
+        self.samples = meta['samples']
+        self.num_sample = len(self.samples)
+        self.allsamples = self.samples
+
+    def select_class(self, cls):
+        new_samples = [sample for sample in self.allsamples if sample[-1] in cls]
+        self.samples = new_samples
+        self.num_sample = len(self.samples)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (sample, target) where target is class_index of the target class.
+        """
+        img_path, target = self.samples[index]
+        sample = pil_loader(self.root + '/' + img_path)
+        if self.transform is not None:
+            sample = self.transform(sample)
+
+        return sample, target, index
+
+    def __len__(self):
+        return len(self.samples)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/data/datasets/list_dataset.py b/maskrcnn_benchmark/data/datasets/list_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2a4f47fc08c8317ade1a762cf4070b6d16a3edf
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/list_dataset.py
@@ -0,0 +1,36 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+Simple dataset class that wraps a list of path names
+"""
+
+from PIL import Image
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+
+class ListDataset(object):
+    def __init__(self, image_lists, transforms=None):
+        self.image_lists = image_lists
+        self.transforms = transforms
+
+    def __getitem__(self, item):
+        img = Image.open(self.image_lists[item]).convert("RGB")
+
+        # dummy target
+        w, h = img.size
+        target = BoxList([[0, 0, w, h]], img.size, mode="xyxy")
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.image_lists)
+
+    def get_img_info(self, item):
+        """
+        Return the image dimensions for the image, without
+        loading and pre-processing it
+        """
+        pass
diff --git a/maskrcnn_benchmark/data/datasets/lvis.py b/maskrcnn_benchmark/data/datasets/lvis.py
new file mode 100644
index 0000000000000000000000000000000000000000..753bcbc836a855f967403b78f4e843a86ce77e39
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/lvis.py
@@ -0,0 +1,268 @@
+# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import json
+import os
+import time
+from collections import defaultdict
+
+import pycocotools.mask as mask_utils
+import torchvision
+from PIL import Image
+
+# from .coco import ConvertCocoPolysToMask, make_coco_transforms
+from .modulated_coco import ConvertCocoPolysToMask
+
+
+def _isArrayLike(obj):
+    return hasattr(obj, "__iter__") and hasattr(obj, "__len__")
+
+
+class LVIS:
+    def __init__(self, annotation_path=None):
+        """Class for reading and visualizing annotations.
+        Args:
+            annotation_path (str): location of annotation file
+        """
+        self.anns = {}
+        self.cats = {}
+        self.imgs = {}
+        self.img_ann_map = defaultdict(list)
+        self.cat_img_map = defaultdict(list)
+        self.dataset = {}
+
+        if annotation_path is not None:
+            print("Loading annotations.")
+
+            tic = time.time()
+            self.dataset = self._load_json(annotation_path)
+            print("Done (t={:0.2f}s)".format(time.time() - tic))
+
+            assert type(self.dataset) == dict, "Annotation file format {} not supported.".format(type(self.dataset))
+            self._create_index()
+
+    def _load_json(self, path):
+        with open(path, "r") as f:
+            return json.load(f)
+
+    def _create_index(self):
+        print("Creating index.")
+
+        self.img_ann_map = defaultdict(list)
+        self.cat_img_map = defaultdict(list)
+
+        self.anns = {}
+        self.cats = {}
+        self.imgs = {}
+
+        for ann in self.dataset["annotations"]:
+            self.img_ann_map[ann["image_id"]].append(ann)
+            self.anns[ann["id"]] = ann
+
+        for img in self.dataset["images"]:
+            self.imgs[img["id"]] = img
+
+        for cat in self.dataset["categories"]:
+            self.cats[cat["id"]] = cat
+
+        for ann in self.dataset["annotations"]:
+            self.cat_img_map[ann["category_id"]].append(ann["image_id"])
+
+        print("Index created.")
+
+    def get_ann_ids(self, img_ids=None, cat_ids=None, area_rng=None):
+        """Get ann ids that satisfy given filter conditions.
+        Args:
+            img_ids (int array): get anns for given imgs
+            cat_ids (int array): get anns for given cats
+            area_rng (float array): get anns for a given area range. e.g [0, inf]
+        Returns:
+            ids (int array): integer array of ann ids
+        """
+        if img_ids is not None:
+            img_ids = img_ids if _isArrayLike(img_ids) else [img_ids]
+        if cat_ids is not None:
+            cat_ids = cat_ids if _isArrayLike(cat_ids) else [cat_ids]
+        anns = []
+        if img_ids is not None:
+            for img_id in img_ids:
+                anns.extend(self.img_ann_map[img_id])
+        else:
+            anns = self.dataset["annotations"]
+
+        # return early if no more filtering required
+        if cat_ids is None and area_rng is None:
+            return [_ann["id"] for _ann in anns]
+
+        cat_ids = set(cat_ids)
+
+        if area_rng is None:
+            area_rng = [0, float("inf")]
+
+        ann_ids = [
+            _ann["id"]
+            for _ann in anns
+            if _ann["category_id"] in cat_ids and _ann["area"] > area_rng[0] and _ann["area"] < area_rng[1]
+        ]
+        return ann_ids
+
+    def get_cat_ids(self):
+        """Get all category ids.
+        Returns:
+            ids (int array): integer array of category ids
+        """
+        return list(self.cats.keys())
+
+    def get_img_ids(self):
+        """Get all img ids.
+        Returns:
+            ids (int array): integer array of image ids
+        """
+        return list(self.imgs.keys())
+
+    def _load_helper(self, _dict, ids):
+        if ids is None:
+            return list(_dict.values())
+        elif _isArrayLike(ids):
+            return [_dict[id] for id in ids]
+        else:
+            return [_dict[ids]]
+
+    def load_anns(self, ids=None):
+        """Load anns with the specified ids. If ids=None load all anns.
+        Args:
+            ids (int array): integer array of annotation ids
+        Returns:
+            anns (dict array) : loaded annotation objects
+        """
+        return self._load_helper(self.anns, ids)
+
+    def load_cats(self, ids):
+        """Load categories with the specified ids. If ids=None load all
+        categories.
+        Args:
+            ids (int array): integer array of category ids
+        Returns:
+            cats (dict array) : loaded category dicts
+        """
+        return self._load_helper(self.cats, ids)
+
+    def load_imgs(self, ids):
+        """Load categories with the specified ids. If ids=None load all images.
+        Args:
+            ids (int array): integer array of image ids
+        Returns:
+            imgs (dict array) : loaded image dicts
+        """
+        return self._load_helper(self.imgs, ids)
+
+    def download(self, save_dir, img_ids=None):
+        """Download images from mscoco.org server.
+        Args:
+            save_dir (str): dir to save downloaded images
+            img_ids (int array): img ids of images to download
+        """
+        imgs = self.load_imgs(img_ids)
+
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+
+        for img in imgs:
+            file_name = os.path.join(save_dir, img["file_name"])
+            if not os.path.exists(file_name):
+                from urllib.request import urlretrieve
+
+                urlretrieve(img["coco_url"], file_name)
+
+    def ann_to_rle(self, ann):
+        """Convert annotation which can be polygons, uncompressed RLE to RLE.
+        Args:
+            ann (dict) : annotation object
+        Returns:
+            ann (rle)
+        """
+        img_data = self.imgs[ann["image_id"]]
+        h, w = img_data["height"], img_data["width"]
+        segm = ann["segmentation"]
+        if isinstance(segm, list):
+            # polygon -- a single object might consist of multiple parts
+            # we merge all parts into one mask rle code
+            rles = mask_utils.frPyObjects(segm, h, w)
+            rle = mask_utils.merge(rles)
+        elif isinstance(segm["counts"], list):
+            # uncompressed RLE
+            rle = mask_utils.frPyObjects(segm, h, w)
+        else:
+            # rle
+            rle = ann["segmentation"]
+        return rle
+
+    def ann_to_mask(self, ann):
+        """Convert annotation which can be polygons, uncompressed RLE, or RLE
+        to binary mask.
+        Args:
+            ann (dict) : annotation object
+        Returns:
+            binary mask (numpy 2D array)
+        """
+        rle = self.ann_to_rle(ann)
+        return mask_utils.decode(rle)
+
+
+class LvisDetectionBase(torchvision.datasets.VisionDataset):
+    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+        super(LvisDetectionBase, self).__init__(root, transforms, transform, target_transform)
+        self.lvis = LVIS(annFile)
+        self.ids = list(sorted(self.lvis.imgs.keys()))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+        """
+        lvis = self.lvis
+        img_id = self.ids[index]
+        ann_ids = lvis.get_ann_ids(img_ids=img_id)
+        target = lvis.load_anns(ann_ids)
+
+        path = "/".join(self.lvis.load_imgs(img_id)[0]["coco_url"].split("/")[-2:])
+
+        img = Image.open(os.path.join(self.root, path)).convert("RGB")
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+    
+
+    def __len__(self):
+        return len(self.ids)
+
+
+class LvisDetection(LvisDetectionBase):
+    def __init__(self, img_folder, ann_file, transforms, return_masks=False, **kwargs):
+        super(LvisDetection, self).__init__(img_folder, ann_file)
+        self.ann_file = ann_file
+        self._transforms = transforms
+        self.prepare = ConvertCocoPolysToMask(return_masks)
+
+    def __getitem__(self, idx):
+        img, target = super(LvisDetection, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        target = {"image_id": image_id, "annotations": target}
+        img, target = self.prepare(img, target)
+        if self._transforms is not None:
+            img = self._transforms(img)
+        return img, target, idx
+    
+    def get_raw_image(self, idx):
+        img, target = super(LvisDetection, self).__getitem__(idx)
+        return img
+    
+    def categories(self):
+        id2cat = {c["id"]: c for c in self.lvis.dataset["categories"]}
+        all_cats = sorted(list(id2cat.keys()))
+        categories = {}
+        for l in list(all_cats):
+            categories[l] = id2cat[l]['name']
+        return categories
\ No newline at end of file
diff --git a/maskrcnn_benchmark/data/datasets/mixed.py b/maskrcnn_benchmark/data/datasets/mixed.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aec54451233173b3e9de107593c87feeb8a3691
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/mixed.py
@@ -0,0 +1,145 @@
+import os
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+from PIL import Image, ImageDraw
+from torchvision.datasets.vision import VisionDataset
+
+from .modulated_coco import ConvertCocoPolysToMask, has_valid_annotation
+
+
+class CustomCocoDetection(VisionDataset):
+    """Coco-style dataset imported from TorchVision.
+        It is modified to handle several image sources
+
+    Args:
+        root_coco (string): Path to the coco images
+        root_vg (string): Path to the vg images
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    def __init__(
+            self,
+            root_coco: str,
+            root_vg: str,
+            annFile: str,
+            transform: Optional[Callable] = None,
+            target_transform: Optional[Callable] = None,
+            transforms: Optional[Callable] = None,
+    ) -> None:
+        super(CustomCocoDetection, self).__init__(root_coco, transforms, transform, target_transform)
+        from pycocotools.coco import COCO
+
+        self.coco = COCO(annFile)
+        self.ids = list(sorted(self.coco.imgs.keys()))
+
+        ids = []
+        for img_id in self.ids:
+            if isinstance(img_id, str):
+                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
+            else:
+                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+            anno = self.coco.loadAnns(ann_ids)
+            if has_valid_annotation(anno):
+                ids.append(img_id)
+        self.ids = ids
+
+        self.root_coco = root_coco
+        self.root_vg = root_vg
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+        """
+        coco = self.coco
+        img_id = self.ids[index]
+        ann_ids = coco.getAnnIds(imgIds=img_id)
+        target = coco.loadAnns(ann_ids)
+
+        img_info = coco.loadImgs(img_id)[0]
+        path = img_info["file_name"]
+        dataset = img_info["data_source"]
+
+        cur_root = self.root_coco if dataset == "coco" else self.root_vg
+        img = Image.open(os.path.join(cur_root, path)).convert("RGB")
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.ids)
+
+
+class MixedDataset(CustomCocoDetection):
+    """Same as the modulated detection dataset, except with multiple img sources"""
+
+    def __init__(self,
+                 img_folder_coco,
+                 img_folder_vg,
+                 ann_file,
+                 transforms,
+                 return_masks,
+                 return_tokens,
+                 tokenizer=None,
+                 disable_clip_to_image=False,
+                 no_mask_for_gold=False,
+                 max_query_len=256,
+                 **kwargs):
+        super(MixedDataset, self).__init__(img_folder_coco, img_folder_vg, ann_file)
+        self._transforms = transforms
+        self.max_query_len = max_query_len
+        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
+        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
+        self.disable_clip_to_image = disable_clip_to_image
+        self.no_mask_for_gold = no_mask_for_gold
+
+    def __getitem__(self, idx):
+        img, target = super(MixedDataset, self).__getitem__(idx)
+
+        image_id = self.ids[idx]
+        caption = self.coco.loadImgs(image_id)[0]["caption"]
+        anno = {"image_id": image_id, "annotations": target, "caption": caption}
+        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
+        if self.no_mask_for_gold:
+            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
+
+        img, anno = self.prepare(img, anno)
+
+        # convert to BoxList (bboxes, labels)
+        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
+        target = BoxList(boxes, img.size, mode="xyxy")
+        classes = anno["labels"]
+        target.add_field("labels", classes)
+        if not self.disable_clip_to_image:
+            num_boxes = len(boxes)
+            target = target.clip_to_image(remove_empty=True)
+            assert len(target.bbox) == num_boxes, "Box removed in MixedDataset!!!"
+
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+
+        # add additional property
+        for ann in anno:
+            target.add_field(ann, anno[ann])
+
+        return img, target, idx
+
+    def get_img_info(self, index):
+        img_id = self.id_to_img_map[index]
+        img_data = self.coco.imgs[img_id]
+        return img_data
diff --git a/maskrcnn_benchmark/data/datasets/mixup.py b/maskrcnn_benchmark/data/datasets/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..110775727526137e5f9af7a85619f6e268b9cdbd
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/mixup.py
@@ -0,0 +1,124 @@
+"""Mixup detection dataset wrapper."""
+from __future__ import absolute_import
+import numpy as np
+import torch
+import torch.utils.data as data
+
+
+class MixupDetection(data.Dataset):
+    """Detection dataset wrapper that performs mixup for normal dataset.
+    Parameters
+    ----------
+    dataset : mx.gluon.data.Dataset
+        Gluon dataset object.
+    mixup : callable random generator, e.g. np.random.uniform
+        A random mixup ratio sampler, preferably a random generator from numpy.random
+        A random float will be sampled each time with mixup(*args).
+        Use None to disable.
+    *args : list
+        Additional arguments for mixup random sampler.
+    """
+    def __init__(self, dataset, mixup=None, preproc=None, *args):
+        super().__init__(dataset.input_dim)
+        self._dataset = dataset
+        self.preproc = preproc
+        self._mixup = mixup
+        self._mixup_args = args
+
+    def set_mixup(self, mixup=None, *args):
+        """Set mixup random sampler, use None to disable.
+        Parameters
+        ----------
+        mixup : callable random generator, e.g. np.random.uniform
+            A random mixup ratio sampler, preferably a random generator from numpy.random
+            A random float will be sampled each time with mixup(*args)
+        *args : list
+            Additional arguments for mixup random sampler.
+        """
+        self._mixup = mixup
+        self._mixup_args = args
+
+    def __len__(self):
+        return len(self._dataset)
+
+    @Dataset.resize_getitem
+    def __getitem__(self, idx):
+        self._dataset._input_dim = self.input_dim
+        # first image
+        img1, label1, _, _= self._dataset.pull_item(idx)
+        lambd = 1
+
+        # draw a random lambda ratio from distribution
+        if self._mixup is not None:
+            lambd = max(0, min(1, self._mixup(*self._mixup_args)))
+
+        if lambd >= 1:
+            weights1 = np.ones((label1.shape[0], 1))
+            label1 = np.hstack((label1, weights1))
+            height, width, _ = img1.shape
+            img_info = (width, height)
+            if self.preproc is not None:
+                img_o, target_o = self.preproc(img1, label1, self.input_dim)
+            return img_o, target_o, img_info, idx
+
+        # second image
+        idx2 = int(np.random.choice(np.delete(np.arange(len(self)), idx)))
+        img2, label2, _, _ = self._dataset.pull_item(idx2)
+
+        # mixup two images
+        height = max(img1.shape[0], img2.shape[0])
+        width = max(img1.shape[1], img2.shape[1])
+        mix_img = np.zeros((height, width, 3),dtype=np.float32)
+        mix_img[:img1.shape[0], :img1.shape[1], :] = img1.astype(np.float32) * lambd
+        mix_img[:img2.shape[0], :img2.shape[1], :] += img2.astype(np.float32) * (1. - lambd)
+        mix_img = mix_img.astype(np.uint8)
+
+        y1 = np.hstack((label1, np.full((label1.shape[0], 1), lambd)))
+        y2 = np.hstack((label2, np.full((label2.shape[0], 1), 1. - lambd)))
+        mix_label = np.vstack((y1, y2))
+        if self.preproc is not None:
+            mix_img, padded_labels = self.preproc(mix_img, mix_label, self.input_dim)
+
+        img_info = (width, height)
+
+        return mix_img, padded_labels, img_info , idx
+
+    def pull_item(self, idx):
+        self._dataset._input_dim = self.input_dim
+        # first image
+        img1, label1, _, _= self._dataset.pull_item(idx)
+        lambd = 1
+
+        # draw a random lambda ratio from distribution
+        if self._mixup is not None:
+            lambd = max(0, min(1, self._mixup(*self._mixup_args)))
+
+        if lambd >= 1:
+            weights1 = np.ones((label1.shape[0], 1))
+            label1 = np.hstack((label1, weights1))
+            height, width, _ = img1.shape
+            img_info = (width, height)
+            if self.preproc is not None:
+                img_o, target_o = self.preproc(img1, label1, self.input_dim)
+            return img_o, target_o, img_info, idx
+
+        # second image
+        idx2 = int(np.random.choice(np.delete(np.arange(len(self)), idx)))
+        img2, label2 = self._dataset.pull_item(idx2)
+
+        # mixup two images
+        height = max(img1.shape[0], img2.shape[0])
+        width = max(img1.shape[1], img2.shape[1])
+        mix_img = np.zeros((height, width, 3),dtype=np.float32)
+        mix_img[:img1.shape[0], :img1.shape[1], :] = img1.astype(np.float32) * lambd
+        mix_img[:img2.shape[0], :img2.shape[1], :] += img2.astype(np.float32) * (1. - lambd)
+        mix_img = mix_img.astype(np.uint8)
+
+        y1 = np.hstack((label1, np.full((label1.shape[0], 1), lambd)))
+        y2 = np.hstack((label2, np.full((label2.shape[0], 1), 1. - lambd)))
+        mix_label = np.vstack((y1, y2))
+        if self.preproc is not None:
+            mix_img, padded_labels = self.preproc(mix_img, mix_label, self.input_dim)
+
+        img_info = (width, height)
+        return mix_img, padded_labels, img_info , idx
diff --git a/maskrcnn_benchmark/data/datasets/modulated_coco.py b/maskrcnn_benchmark/data/datasets/modulated_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..23f6d3610a1231bb0ae0c99affe7374b9551df96
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/modulated_coco.py
@@ -0,0 +1,654 @@
+import logging
+import os
+import os.path
+import math
+from PIL import Image, ImageDraw
+
+import random
+import numpy as np
+
+import torch
+import torchvision
+import torch.utils.data as data
+from pycocotools import mask as coco_mask
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
+from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation
+from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od
+import pdb
+import json
+
+class CocoGrounding(torchvision.datasets.CocoDetection):
+    def __init__(self,
+                 img_folder,
+                 ann_file,
+                 transforms,
+                 return_masks,
+                 return_tokens,
+                 is_train=False,
+                 tokenizer=None,
+                 disable_shuffle=False,
+                 add_detection_prompt=False,
+                 one_hot=False,
+                 disable_clip_to_image=False,
+                 no_minus_one_for_one_hot=False,
+                 separation_tokens=" ",
+                 few_shot=0,
+                 no_mask_for_od=False,
+                 override_category=None,
+                 use_caption_prompt=False,
+                 caption_prompt=None,
+                 max_query_len=256,
+                 special_safeguard_for_coco_grounding=False,
+                 random_sample_negative=-1,
+                 **kwargs
+                 ):
+        super(CocoGrounding, self).__init__(img_folder, ann_file)
+        self.ids = sorted(self.ids)
+
+        ids = []
+        for img_id in self.ids:
+            if isinstance(img_id, str):
+                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
+            else:
+                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+            anno = self.coco.loadAnns(ann_ids)
+            if has_valid_annotation(anno):
+                ids.append(img_id)
+
+        self.ids = ids
+        
+        if few_shot:
+            ids = []
+            # cats_freq = [few_shot]*len(self.coco.cats.keys())
+            cats_freq = [few_shot]*max(list(self.coco.cats.keys()))
+            for img_id in self.ids:
+                if isinstance(img_id, str):
+                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
+                else:
+                    ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+                anno = self.coco.loadAnns(ann_ids)
+                cat = set([ann['category_id'] for ann in anno]) #set/tuple corresponde to instance/image level
+                is_needed = sum([cats_freq[c-1]>0 for c in cat])
+                if is_needed:
+                    ids.append(img_id)
+                    for c in cat:
+                        cats_freq[c-1] -= 1
+                    # print(cat, cats_freq)
+            self.ids = ids
+
+
+
+        self.json_category_id_to_contiguous_id = {
+            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
+        }
+        self.contiguous_category_id_to_json_id = {
+            v: k for k, v in self.json_category_id_to_contiguous_id.items()
+        }
+
+        if override_category is not None:
+            self.coco.dataset["categories"] = override_category
+        self.use_caption_prompt = use_caption_prompt
+        self.caption_prompt = caption_prompt
+        self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding
+        self.random_sample_negative = random_sample_negative
+        self.ind_to_class = self.categories(no_background=False)
+        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
+        self._transforms = transforms
+        self.max_query_len = max_query_len
+        self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
+        self.tokenizer = tokenizer
+        self.is_train = is_train
+
+        self.ind_to_class = self.categories(no_background=False)
+
+        self.disable_shuffle = disable_shuffle
+        self.add_detection_prompt = add_detection_prompt
+        self.one_hot = one_hot
+        self.no_minus_one_for_one_hot = no_minus_one_for_one_hot
+
+        self.disable_clip_to_image = disable_clip_to_image
+        self.separation_tokens = separation_tokens
+        self.no_mask_for_od = no_mask_for_od
+        self.return_masks = return_masks
+
+    def categories(self, no_background=True):
+        categories = self.coco.dataset["categories"]
+        label_list = {}
+        for index, i in enumerate(categories):
+            # assert(index + 1 == i["id"])
+            if not no_background or (i["name"] != "__background__" and i['id'] != 0):
+                label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"]
+        return label_list
+
+    def get_box_mask(self, rect, img_size, mode="poly"):
+        assert mode=="poly", "Only support poly mask right now!"
+        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
+        return [[x1, y1, x1, y2, x2, y2, x2, y1]]
+
+    def __getitem__(self, idx):
+        img, tgt = super(CocoGrounding, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        tgt = [obj for obj in tgt if obj["iscrowd"] == 0]
+        boxes = [obj["bbox"] for obj in tgt]
+        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
+        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
+        classes = [obj["category_id"] for obj in tgt]
+        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
+        classes = torch.tensor(classes)
+        target.add_field("labels", classes)
+
+        if self.return_masks:
+            masks = []
+            is_box_mask = []
+            for obj, bbox in zip(tgt, target.bbox):
+                if "segmentation" in obj:
+                    masks.append(obj["segmentation"])
+                    is_box_mask.append(0)
+                else:
+                    masks.append(self.get_box_mask(bbox, img.size, mode="poly"))
+                    is_box_mask.append(1)
+            masks = SegmentationMask(masks, img.size, mode="poly")
+            is_box_mask = torch.tensor(is_box_mask)
+            target.add_field("masks", masks)
+            target.add_field("is_box_mask", is_box_mask)
+        
+        if not self.disable_clip_to_image:
+            target = target.clip_to_image(remove_empty=True)
+        
+        if self.special_safeguard_for_coco_grounding:
+            # Intended for LVIS
+            assert(not self.use_caption_prompt)
+
+            original_box_num = len(target)
+            target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2) # leave some space for the special tokens
+            if len(target) < original_box_num:
+                print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))
+
+            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
+                target=target,
+                image_id=image_id,
+                ind_to_class=self.ind_to_class,
+                disable_shuffle=self.disable_shuffle,
+                add_detection_prompt=False,
+                add_detection_prompt_advanced=False,
+                random_sample_negative=self.random_sample_negative,
+                control_probabilities=(0.0, 0.0, 1.0, 0.0), # always try to add a lot of negatives
+                restricted_negative_list=None,
+                separation_tokens=self.separation_tokens,
+                max_num_labels=-1,
+                positive_caption_length=positive_caption_length,
+                tokenizer=self.tokenizer,
+                max_seq_length=self.max_query_len-2
+            )
+        else:
+            # Intended for COCO / ODinW
+            annotations, caption, greenlight_span_for_masked_lm_objective = convert_od_to_grounding_simple(
+                target=target,
+                image_id=image_id,
+                ind_to_class=self.ind_to_class,
+                disable_shuffle=self.disable_shuffle,
+                add_detection_prompt=self.add_detection_prompt,
+                separation_tokens=self.separation_tokens,
+                caption_prompt=self.caption_prompt if self.use_caption_prompt else None,
+            )
+
+        anno = {"image_id": image_id, "annotations": annotations, "caption": caption}
+        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
+        if self.no_mask_for_od:
+            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
+        img, anno = self.prepare(img, anno, box_format="xyxy")
+
+        # for equivalence check
+        if self.one_hot:
+            logging.info("using one hot for equivalence check.")
+            one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float)
+            text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64)
+            # create one hot mapping
+            for ii, cls in enumerate(classes):
+                if self.no_minus_one_for_one_hot:
+                    one_hot_map[ii, cls] = 1.0
+                else:
+                    one_hot_map[ii, cls - 1] = 1.0
+            if self.no_minus_one_for_one_hot:
+                text_mask[:] = 1
+            else:
+                text_mask[:len(self.ind_to_class)] = 1
+            anno["positive_map"] = one_hot_map
+            anno["text_mask"] = text_mask
+
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+
+        # add additional property
+        for ann in anno:
+            target.add_field(ann, anno[ann])
+        
+        sanity_check_target_after_processing(target)
+
+        return img, target, idx
+
+    def get_img_info(self, index):
+        img_id = self.id_to_img_map[index]
+        img_data = self.coco.imgs[img_id]
+        return img_data
+
+
+class ModulatedDataset(torchvision.datasets.CocoDetection):
+    def __init__(self,
+                 img_folder,
+                 ann_file,
+                 transforms,
+                 return_masks,
+                 return_tokens,
+                 is_train=False,
+                 tokenizer=None,
+                 disable_clip_to_image=False,
+                 no_mask_for_gold=False,
+                 max_query_len=256,
+                 **kwargs):
+        super(ModulatedDataset, self).__init__(img_folder, ann_file)
+        self.ids = sorted(self.ids)
+
+        ids = []
+        for img_id in self.ids:
+            if isinstance(img_id, str):
+                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
+            else:
+                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+            anno = self.coco.loadAnns(ann_ids)
+            if has_valid_annotation(anno):
+                ids.append(img_id)
+        self.ids = ids
+
+        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
+        self._transforms = transforms
+        self.max_query_len = max_query_len
+        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
+        self.is_train = is_train
+        self.disable_clip_to_image = disable_clip_to_image
+        self.no_mask_for_gold = no_mask_for_gold
+
+    def __getitem__(self, idx):
+        img, target = super(ModulatedDataset, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        coco_img = self.coco.loadImgs(image_id)[0]
+        caption = coco_img["caption"]
+        dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None
+        anno = {"image_id": image_id, "annotations": target, "caption": caption}
+
+        # This dataset is used for Flickr & Mixed, so the sequence is maskable
+        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
+        if self.no_mask_for_gold:
+            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
+        img, anno = self.prepare(img, anno)
+
+        # convert to BoxList (bboxes, labels)
+        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
+        target = BoxList(boxes, img.size, mode="xyxy")
+        classes = anno["labels"]
+        target.add_field("labels", classes)
+        if self.prepare.return_masks:
+            target.add_field("masks", anno.pop("masks"))
+            target.add_field("is_box_mask", anno.pop("is_box_mask"))
+        if not self.disable_clip_to_image:
+            num_boxes = len(target.bbox)
+            target = target.clip_to_image(remove_empty=True)
+            assert num_boxes == len(target.bbox), "Box got removed in MixedDataset!!!"
+
+        # Check if bboxes are correct
+        # draw = ImageDraw.Draw(img)
+        # boxes = target.bbox
+        # for box in boxes:
+        #     draw.rectangle([box[0], box[1], box[2], box[3]])
+        # img.save('OUTPUT/images/{}.jpg'.format(idx))
+
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+
+        # add additional property
+        for ann in anno:
+            target.add_field(ann, anno[ann])
+
+        target.add_field("dataset_name", dataset_name)
+        for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]:
+            if extra_key in coco_img:
+                target.add_field(extra_key, coco_img[extra_key])
+
+        if "tokens_positive_eval" in coco_img and not self.is_train:
+            tokenized = self.prepare.tokenizer(caption, return_tensors="pt")
+            target.add_field("positive_map_eval", create_positive_map(tokenized, coco_img["tokens_positive_eval"]))
+            target.add_field("nb_eval", len(target.get_field("positive_map_eval")))
+
+        sanity_check_target_after_processing(target)
+        return img, target, idx
+
+    def get_img_info(self, index):
+        img_id = self.id_to_img_map[index]
+        img_data = self.coco.imgs[img_id]
+        return img_data
+
+
+class CocoDetection(data.Dataset):
+    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(self, root, annFile, transform=None, target_transform=None):
+        from pycocotools.coco import COCO
+        self.root = root
+        self.coco = COCO(annFile)
+        self.ids = list(self.coco.imgs.keys())
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __getitem__(self, index, return_meta=False):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+        """
+        coco = self.coco
+        img_id = self.ids[index]
+        if isinstance(img_id, str):
+            img_id = [img_id]
+        ann_ids = coco.getAnnIds(imgIds=img_id)
+        target = coco.loadAnns(ann_ids)
+
+        meta = coco.loadImgs(img_id)[0]
+        path = meta['file_name']
+        img = pil_loader(os.path.join(self.root, path))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        if return_meta:
+            return img, target, meta
+        else:
+            return img, target
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __repr__(self):
+        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
+        fmt_str += '    Root Location: {}\n'.format(self.root)
+        tmp = '    Transforms (if any): '
+        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        tmp = '    Target Transforms (if any): '
+        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        return fmt_str
+
+
+class ConvertCocoPolysToMask(object):
+    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256):
+        self.return_masks = return_masks
+        self.return_tokens = return_tokens
+        self.tokenizer = tokenizer
+        self.max_query_len = max_query_len
+
+    def get_box_mask(self, rect, img_size, mode="poly"):
+        assert mode=="poly", "Only support poly mask right now!"
+        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
+        return [[x1, y1, x1, y2, x2, y2, x2, y1]]
+
+    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
+        w, h = image.size
+
+        image_id = target["image_id"]
+        image_id = torch.tensor([image_id])
+
+        anno = target["annotations"]
+        caption = target["caption"] if "caption" in target else None
+        label_to_positions = target.get("label_to_positions", {})
+
+        greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None)
+
+        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]
+
+        boxes = [obj["bbox"] for obj in anno]
+        # guard against no boxes via resizing
+        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+        if box_format == "xywh":
+            boxes[:, 2:] += boxes[:, :2] - 1  # TO_REMOVE = 1
+            boxes[:, 0::2].clamp_(min=0, max=w-1)  # TO_REMOVE = 1
+            boxes[:, 1::2].clamp_(min=0, max=h-1)  # TO_REMOVE = 1
+
+        classes = [obj["category_id"] for obj in anno]
+        classes = torch.tensor(classes, dtype=torch.int64)
+
+        if self.return_masks:
+            masks = []
+            is_box_mask = []
+            for obj, bbox in zip(anno, boxes):
+                if "segmentation" in obj:
+                    masks.append(obj["segmentation"])
+                    is_box_mask.append(0)
+                else:
+                    masks.append(self.get_box_mask(bbox, image.size, mode='poly'))
+                    is_box_mask.append(1)
+            masks = SegmentationMask(masks, image.size, mode='poly')
+            is_box_mask = torch.tensor(is_box_mask)
+
+        keypoints = None
+        if anno and "keypoints" in anno[0]:
+            keypoints = [obj["keypoints"] for obj in anno]
+            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+            num_keypoints = keypoints.shape[0]
+            if num_keypoints:
+                keypoints = keypoints.view(num_keypoints, -1, 3)
+
+        isfinal = None
+        if anno and "isfinal" in anno[0]:
+            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)
+
+        tokens_positive = [] if self.return_tokens else None
+        if self.return_tokens and anno and "tokens" in anno[0]:
+            tokens_positive = [obj["tokens"] for obj in anno]
+        elif self.return_tokens and anno and "tokens_positive" in anno[0]:
+            tokens_positive = [obj["tokens_positive"] for obj in anno]
+
+        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+        boxes = boxes[keep]
+        classes = classes[keep]
+        if self.return_masks:
+            masks = masks[keep]
+            is_box_mask = is_box_mask[keep]
+        if keypoints is not None:
+            keypoints = keypoints[keep]
+
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = classes
+        if caption is not None:
+            target["caption"] = caption
+        if self.return_masks:
+            target["masks"] = masks
+            target["is_box_mask"] = is_box_mask
+        target["image_id"] = image_id
+        if keypoints is not None:
+            target["keypoints"] = keypoints
+
+        if tokens_positive is not None:
+            target["tokens_positive"] = []
+
+            for i, k in enumerate(keep):
+                if k or ignore_box_screen:
+                    target["tokens_positive"].append(tokens_positive[i])
+
+        if isfinal is not None:
+            target["isfinal"] = isfinal
+
+        # for conversion to coco api
+        area = torch.tensor([obj["area"] for obj in anno])
+        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
+        target["area"] = area[keep]
+        target["iscrowd"] = iscrowd[keep]
+
+        target["orig_size"] = torch.as_tensor([int(h), int(w)])
+        target["size"] = torch.as_tensor([int(h), int(w)])
+
+        if self.return_tokens and self.tokenizer is not None:
+            if not ignore_box_screen:
+                assert len(target["boxes"]) == len(target["tokens_positive"])
+            tokenized = self.tokenizer(caption, return_tensors="pt",
+                max_length=self.max_query_len,
+                truncation=True)
+            target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"])
+            target['greenlight_map'] = create_greenlight_map(greenlight_span_for_masked_lm_objective,tokenized)
+            target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions)
+
+        original_od_label = []
+        for obj in anno:
+            original_od_label.append(
+                obj.get("original_od_label", -10))  # NOTE: The padding value has to be not the same as -1 or -100
+        target["original_od_label"] = torch.as_tensor(original_od_label)
+
+        return image, target
+
+def create_greenlight_map(tok_list, tokenized):
+    # An example tok_list:
+    # [(0, 5), (10, 13), (-1, -1, -1)]
+    # The last one is a special indicator..
+
+    greenlight_map = torch.zeros(256, dtype=torch.float)
+    for item in tok_list:
+        if len(item) != 2:
+            assert(len(item) == 3)
+            # Make everything unmakable
+            greenlight_map[:] = -1
+            break
+
+        beg, end = item
+        beg_pos = tokenized.char_to_token(beg)
+        end_pos = tokenized.char_to_token(end - 1)
+        if beg_pos is None:
+            try:
+                beg_pos = tokenized.char_to_token(beg + 1)
+                if beg_pos is None:
+                    beg_pos = tokenized.char_to_token(beg + 2)
+            except:
+                beg_pos = None
+        if end_pos is None:
+            try:
+                end_pos = tokenized.char_to_token(end - 2)
+                if end_pos is None:
+                    end_pos = tokenized.char_to_token(end - 3)
+            except:
+                end_pos = None
+        if beg_pos is None or end_pos is None:
+            continue
+
+        assert beg_pos is not None and end_pos is not None
+        greenlight_map[beg_pos: end_pos + 1].fill_(1)
+    return greenlight_map
+
+
+def create_positive_map_for_od_labels(tokenized, label_to_positions):
+    """construct a map such that positive_map[i] = j, where j is the object detection label of the token i"""
+    """
+    {3: [1: 5)}
+    256 : -1 3 3 3 3 -1 .. 8 8 ..
+    the woman in the garden
+    -1 -1 -1 -1 -1
+    """
+    positive_map = torch.ones(256, dtype=torch.float) * -1  # -1 means no match
+    keys = list(label_to_positions.keys())
+    for j, key in enumerate(keys):
+        tok_list = label_to_positions[key]
+        # one label only mapps to one location
+        beg, end = tok_list
+        beg_pos = tokenized.char_to_token(beg)
+        end_pos = tokenized.char_to_token(end - 1)
+        if beg_pos is None:
+            try:
+                beg_pos = tokenized.char_to_token(beg + 1)
+                if beg_pos is None:
+                    beg_pos = tokenized.char_to_token(beg + 2)
+            except:
+                beg_pos = None
+        if end_pos is None:
+            try:
+                end_pos = tokenized.char_to_token(end - 2)
+                if end_pos is None:
+                    end_pos = tokenized.char_to_token(end - 3)
+            except:
+                end_pos = None
+        if beg_pos is None or end_pos is None:
+            continue
+        assert beg_pos is not None and end_pos is not None
+        positive_map[beg_pos: end_pos + 1].fill_(key)
+    return positive_map
+
+
+def convert_coco_poly_to_mask(segmentations, height, width):
+    masks = []
+    for polygons in segmentations:
+        rles = coco_mask.frPyObjects(polygons, height, width)
+        mask = coco_mask.decode(rles)
+        if len(mask.shape) < 3:
+            mask = mask[..., None]
+        mask = torch.as_tensor(mask, dtype=torch.uint8)
+        mask = mask.any(dim=2)
+        masks.append(mask)
+    if masks:
+        masks = torch.stack(masks, dim=0)
+    else:
+        masks = torch.zeros((0, height, width), dtype=torch.uint8)
+    return masks
+
+
+def create_positive_map(tokenized, tokens_positive):
+    """construct a map such that positive_map[i,j] = True iff box i is associated to token j"""
+    positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)
+
+    for j, tok_list in enumerate(tokens_positive):
+        for (beg, end) in tok_list:
+            beg_pos = tokenized.char_to_token(beg)
+            end_pos = tokenized.char_to_token(end - 1)
+            if beg_pos is None:
+                try:
+                    beg_pos = tokenized.char_to_token(beg + 1)
+                    if beg_pos is None:
+                        beg_pos = tokenized.char_to_token(beg + 2)
+                except:
+                    beg_pos = None
+            if end_pos is None:
+                try:
+                    end_pos = tokenized.char_to_token(end - 2)
+                    if end_pos is None:
+                        end_pos = tokenized.char_to_token(end - 3)
+                except:
+                    end_pos = None
+            if beg_pos is None or end_pos is None:
+                continue
+
+            assert beg_pos is not None and end_pos is not None
+            positive_map[j, beg_pos: end_pos + 1].fill_(1)
+    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
+
+
+def pil_loader(path, retry=5):
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    ri = 0
+    while ri < retry:
+        try:
+            with open(path, 'rb') as f:
+                img = Image.open(f)
+                return img.convert('RGB')
+        except:
+            ri += 1
diff --git a/maskrcnn_benchmark/data/datasets/object365.py b/maskrcnn_benchmark/data/datasets/object365.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa9bb4aabe13237b9fad229b310be8b50e31727b
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/object365.py
@@ -0,0 +1,8 @@
+import torch
+import torchvision
+import torch.utils.data as data
+from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV
+
+
+class Object365DetectionTSV(CocoDetectionTSV):
+    pass
diff --git a/maskrcnn_benchmark/data/datasets/od_to_grounding.py b/maskrcnn_benchmark/data/datasets/od_to_grounding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b93aace9ea6e08a0ae8e7d1b8f87729dfcd84bbc
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/od_to_grounding.py
@@ -0,0 +1,375 @@
+import numpy as np
+import random
+import re
+import torch
+import pdb
+import logging
+
+
+def clean_name(name):
+    name = re.sub(r"\(.*\)", "", name)
+    name = re.sub(r"_", " ", name)
+    name = re.sub(r"  ", " ", name)
+    return name
+
+
+def sanity_check_target_after_processing(target):
+    assert(len(target.bbox) == len(target.extra_fields["boxes"]))
+
+
+def convert_od_to_grounding_simple(
+    target, 
+    image_id, 
+    ind_to_class, 
+    disable_shuffle=True, 
+    add_detection_prompt=False, 
+    separation_tokens=" ",
+    caption_prompt=None):
+    """
+    Convert object detection data into grounding data format, on the fly.
+    ind_to_class: {0: "__background__", 1 : "person" ...}, contiguous id
+    """
+
+    def generate_sentence_from_labels(positive_label_list, negative_label_list, disable_shuffle=True):
+        label_to_positions = {}
+        label_list = negative_label_list + positive_label_list
+        if not disable_shuffle:
+            random.shuffle(label_list)
+            assert (caption_prompt is None), "Should not specify caption_prompt when shuffle is enabled!!"  # avoid potential bug
+
+        if add_detection_prompt:
+            pheso_caption = "object detection : "
+        else:
+            pheso_caption = ""
+        
+
+        for index, label in enumerate(label_list):
+            if caption_prompt is not None:
+                pheso_caption += caption_prompt[index]['prefix']
+
+            start_index = len(pheso_caption)
+            if caption_prompt is not None:
+                pheso_caption += clean_name(caption_prompt[index]['name'])
+            else:
+                pheso_caption += clean_name(ind_to_class[label])  # NOTE: slight change...
+            end_index = len(pheso_caption)
+
+            if caption_prompt is not None:
+                pheso_caption += caption_prompt[index]['suffix']
+
+            # e.g.: pheso_caption = "cat dog", where cat is label 4, and dog is label 17
+            # label_to_positions: {4: (0, 3), 17: (4, 7)}
+            label_to_positions[label] = [start_index, end_index]
+
+            if index != len(label_list) - 1:
+                pheso_caption += separation_tokens
+
+        return label_to_positions, pheso_caption
+
+    label_list = list(sorted(ind_to_class.keys()))  # do not include the background
+    label_to_positions, pheso_caption = generate_sentence_from_labels(
+        positive_label_list=label_list,
+        negative_label_list=[],
+        disable_shuffle=disable_shuffle
+    )
+
+    new_target = []
+
+    '''
+    Convert into:
+    {'area': 10506.0, 'iscrowd': 0, 'image_id': 571335, 'category_id': 1, 'id': 2999421, 'bbox': [221, 319, 103, 102], 'tokens_positive': [[0, 3]]} 
+    tokens_positive is the char position
+    '''
+    areas = target.area()
+    greenlight_span_for_masked_lm_objective = []
+    for i in range(len(target)):
+        new_target_i = {}
+        new_target_i["area"] = areas[i]
+        new_target_i["iscrowd"] = 0
+        new_target_i["image_id"] = image_id
+        new_target_i["category_id"] = target.extra_fields["labels"][i].item()
+        new_target_i["id"] = None
+        new_target_i['bbox'] = target.bbox[i].numpy().tolist()
+
+        label_i = target.extra_fields["labels"][i].item()
+
+        if label_i in label_to_positions:  # NOTE: Only add those that actually appear in the final caption
+            new_target_i["tokens_positive"] = [label_to_positions[label_i]]
+            new_target.append(new_target_i)
+            greenlight_span_for_masked_lm_objective.append(label_to_positions[label_i])
+
+    return new_target, pheso_caption, greenlight_span_for_masked_lm_objective
+
+
+def check_for_positive_overflow(target, ind_to_class, tokenizer, max_seq_length=256):
+    # NOTE: Only call this function for OD data; DO NOT USE IT FOR GROUNDING DATA
+    # NOTE: called only in coco_dt
+
+    # Check if we have too many positive labels
+    # generate a caption by appending the positive labels
+    positive_label_set = set()
+    for i in range(len(target)):
+        label_i = target.extra_fields["labels"][i].item()
+        positive_label_set.add(label_i)
+    positive_label_list = list(positive_label_set)
+
+    # random shuffule so we can sample different annotations at different epochs
+    random.shuffle(positive_label_list)
+
+    kept_lables = []
+    length = 0
+
+    for index, label in enumerate(positive_label_list):
+
+        label_text = clean_name(ind_to_class[label]) + ". " # "dog. "
+
+        tokenized = tokenizer.tokenize(label_text)
+
+        length += len(tokenized)
+
+        if length > max_seq_length:
+            break
+        else:
+            kept_lables.append(label)
+    
+    ## filter boxes
+    keep_box_index = []
+    for i in range(len(target)):
+        label_i = target.extra_fields["labels"][i].item()
+        if label_i in kept_lables:
+            keep_box_index.append(i)
+    
+    keep_box_index = torch.LongTensor(keep_box_index)
+
+    target = target[keep_box_index] ## filter boxes
+
+    return target, length
+
+    
+def convert_object_detection_to_grounding_optimized_for_od(
+        target,
+        image_id,
+        ind_to_class,
+        disable_shuffle,
+        add_detection_prompt,
+        add_detection_prompt_advanced,
+        random_sample_negative,
+        control_probabilities,
+        restricted_negative_list=None,
+        separation_tokens=" ",
+        max_num_labels=-1,
+        max_seq_length=256,
+        tokenizer=None,
+        positive_caption_length=0
+):
+    '''
+    ind_to_class: {0: "__background__", 1 : "person" ...}
+    target:
+
+    restricted_negative_list : for datasets with restricted negatives, sample only the negatives
+
+    Convert object detection data into grounding data format, on the fly.
+
+    Control options:
+        1. add_detection_prompt: add "object detection : " to the front of the prompt
+        2. num_negatives: randomly sampled negative classes
+        3. num_positives: how many positives to keep (-1 means do not cut any)
+
+    Probabilities to generate the control options:
+
+        a. probability_one_negative: only give one negative class to mimic evaluation
+        b. probability_one_positive: only give one positive class to mimic evaluation
+        c. probability_full: add both all positive and all negatives
+        d. other:
+            randomly sample some negatives and some positives
+            The below control options are independent of each other:
+            - probability_random_negative: probability of randomly sample X negatives
+            - probability_random_positive: probability of randomly sample some positives
+    '''
+    if restricted_negative_list is None:
+        valid_negative_indexes = list(ind_to_class.keys())
+    else:
+        valid_negative_indexes = restricted_negative_list
+
+    def generate_senetence_given_labels(
+            positive_label_list,
+            negative_label_list,
+            prompt_engineer_version="v2",
+            disable_shuffle=False,
+            positive_question_probability=0.6,
+            negative_question_probability=0.8,
+            full_question_probability=0.5):
+
+        '''
+        v3: with simple prompt such as "there are", "are there?"
+        v4: try to merge some are there / there are together, to avoid sequence being too long
+        '''
+
+        label_to_positions = {}
+
+        assert (prompt_engineer_version == "v2")
+        num_negatives = len(negative_label_list)
+        num_positives = len(positive_label_list)
+        label_list = negative_label_list + positive_label_list
+        if not disable_shuffle:
+            random.shuffle(label_list)
+
+        if add_detection_prompt:
+            if add_detection_prompt_advanced and (num_negatives == 0 or num_positives == 0) and not disable_shuffle:
+                pheso_caption = "object detection query : "
+            else:
+                pheso_caption = "object detection : "
+        else:
+            pheso_caption = ""
+
+        for index, label in enumerate(label_list):
+
+            start_index = len(pheso_caption)
+
+            pheso_caption += clean_name(ind_to_class[label])  # NOTE: slight change...
+            end_index = len(pheso_caption)
+
+            # e.g.: pheso_caption = "cat dog", where cat is label 4, and dog is label 17
+            # label_to_positions: {4: (0, 3), 17: (4, 7)}
+            label_to_positions[label] = [start_index, end_index]
+
+            if index != len(label_list) - 1:
+                pheso_caption += separation_tokens
+
+        return label_to_positions, pheso_caption
+
+    if disable_shuffle:
+        label_list = list(sorted(ind_to_class.keys()))[1:]  # do not include the background
+        label_to_positions, pheso_caption = generate_senetence_given_labels(
+            positive_label_list=label_list,
+            negative_label_list=[],
+            disable_shuffle=True)
+        # print(label_to_positions, pheso_caption)
+    else:
+        positive_label_set = set()
+        for i in range(len(target)):
+            label_i = target.extra_fields["labels"][i].item()
+            positive_label_set.add(label_i)
+
+        full_positive = len(positive_label_set)
+        if max_num_labels <= 0:
+            full_negative = random_sample_negative
+        else:
+            full_negative = max(min(max_num_labels-full_positive, random_sample_negative), 0)
+
+        if full_negative > len(valid_negative_indexes):
+            full_negative = len(valid_negative_indexes)
+
+        num_negatives, num_positives = generate_control_options_given_probabilities(
+            control_probabilities=control_probabilities,
+            full_positive=full_positive,
+            full_negative=full_negative)
+        # num_positives not used
+        
+
+        # Keep some negatives
+        negative_label_list = set()
+        if num_negatives != -1:
+            if num_negatives > len(valid_negative_indexes):
+                num_negatives = len(valid_negative_indexes)
+            for i in np.random.choice(valid_negative_indexes, size=num_negatives, replace=False):
+                # label_sets.add(i)
+                if i not in positive_label_set:
+                    negative_label_list.add(i)
+
+        # Keep all positives; ignoring num_positives
+        positive_label_list = list(positive_label_set)
+        random.shuffle(positive_label_list)
+
+        negative_label_list = list(negative_label_list)  # e.g.: [17, 1, 13] where each number is the class name
+        random.shuffle(negative_label_list)
+
+        # Do a pre-screen. If we cannot afford this many negatives, we will sample less
+        negative_max_length = max_seq_length - positive_caption_length
+        screened_negative_label_list = []
+        for negative_label in negative_label_list:
+            label_text = clean_name(ind_to_class[negative_label]) + ". " # "dog. "
+
+            tokenized = tokenizer.tokenize(label_text)
+            
+            negative_max_length -= len(tokenized)
+
+            if negative_max_length > 0: 
+                screened_negative_label_list.append(negative_label) # keep this negative
+            else:
+                break
+        negative_label_list = screened_negative_label_list
+
+        label_to_positions, pheso_caption = generate_senetence_given_labels(
+            positive_label_list=positive_label_list,
+            negative_label_list=negative_label_list)
+
+    new_target = []
+
+    '''
+    Convert into:
+    {'area': 10506.0, 'iscrowd': 0, 'image_id': 571335, 'category_id': 1, 'id': 2999421, 'bbox': [221, 319, 103, 102], 'tokens_positive': [[0, 3]]} 
+    tokens_positive is the char position
+    '''
+    areas = target.area()
+    greenlight_span_for_masked_lm_objective = []
+    for i in range(len(target)):
+        new_target_i = {}
+        new_target_i["area"] = areas[i]
+        new_target_i["iscrowd"] = 0
+        new_target_i["image_id"] = image_id
+        new_target_i["category_id"] = target.extra_fields["labels"][i].item()
+        new_target_i["id"] = None
+        new_target_i['bbox'] = target.bbox[i].numpy().tolist()
+
+        label_i = target.extra_fields["labels"][i].item()
+        new_target_i["original_od_label"] = label_i
+
+        if label_i in label_to_positions:  # NOTE: Only add those that actually appear in the final caption
+            new_target_i["tokens_positive"] = [label_to_positions[label_i]]
+            new_target.append(new_target_i)
+            greenlight_span_for_masked_lm_objective.append(label_to_positions[label_i])
+
+    return new_target, pheso_caption, greenlight_span_for_masked_lm_objective, label_to_positions
+
+
+def generate_control_options_given_probabilities(
+        control_probabilities,
+        full_positive,
+        full_negative):
+    
+    # The function was originally designed to perform data augmentation by randomly dropping negative and positive classes. Later, we decided to only consider dropping negative classes. So the returned 'num_positives' by this function will be ignored.
+
+    outer_prob = random.random()
+
+    probability_one_negative = control_probabilities[0]
+    probability_one_positive = control_probabilities[1]
+    probability_full = control_probabilities[2]
+    probability_drop_positive = control_probabilities[3]
+
+    assert(probability_drop_positive == 0)
+
+    if outer_prob < probability_one_negative:
+        # a. probability_one_negative: only give one negative class to mimic evaluation (10%)
+        num_negatives = 1
+        num_positives = 0
+    elif outer_prob < probability_one_positive + probability_one_negative:
+        # b. probability_one_positive: only give one positive class to mimic evaluation (10%)
+        num_negatives = 0
+        num_positives = 1
+    elif outer_prob < probability_full + probability_one_positive + probability_one_negative:
+        # c. probability_full: add both all positive and all negatives (20%)
+        num_negatives = full_negative
+        num_positives = full_positive
+    else:
+        if random.random() < 1.0:  # - probability_random_negative: probability of randomly sample X negatives (100%)
+            num_negatives = np.random.choice(max(1, full_negative)) + 1  # mininum 1
+        else:
+            num_negatives = full_negative  # Full
+
+        if random.random() < probability_drop_positive:  #
+            num_positives = np.random.choice(max(1, full_positive)) + 1
+        else:
+            num_positives = full_positive  # Full
+
+    return num_negatives, num_positives
diff --git a/maskrcnn_benchmark/data/datasets/phrasecut.py b/maskrcnn_benchmark/data/datasets/phrasecut.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a68262d2372c69ba9e64535014770ce4be98189
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/phrasecut.py
@@ -0,0 +1,8 @@
+import torch
+import torchvision
+import torch.utils.data as data
+from maskrcnn_benchmark.data.datasets.modulated_coco import ModulatedDataset
+
+
+class PhrasecutDetection(ModulatedDataset):
+    pass
diff --git a/maskrcnn_benchmark/data/datasets/pseudo_data.py b/maskrcnn_benchmark/data/datasets/pseudo_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..70f2ac3e78feed8740f4f9aeec7bf57695f555b3
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/pseudo_data.py
@@ -0,0 +1,228 @@
+import torch
+import torch.distributed as dist
+import time
+from torchvision.ops import nms
+import random
+import numpy as np
+from PIL import Image, ImageDraw
+import pdb
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from .modulated_coco import ConvertCocoPolysToMask
+from .tsv import ODTSVDataset, TSVYamlDataset
+from .od_to_grounding import sanity_check_target_after_processing
+from copy import deepcopy
+
+class PseudoData(TSVYamlDataset):
+    def __init__(self,
+                 yaml_file,
+                 transforms,
+                 return_tokens,
+                 return_masks,
+                 tokenizer,
+                 caption_min_box=1,
+                 replace_clean_label=False,
+                 further_screen=False,
+                 caption_conf=0.5,
+                 caption_nms=-1,
+                 pack_random_caption_number=0,
+                 inference_caption=False,
+                 sample_negative_for_grounding_data=-1,
+                 random_pack_prob=-1.0,
+                 no_random_pack_probability=0.0,
+                 safeguard_positive_caption=True,
+                 mlm_obj_for_only_positive=False,
+                 caption_format_version="v1",
+                 local_debug=False,
+                 max_query_len=256,
+                 diver_box_for_vqa=False,
+                 **kwargs
+                 ):
+        super(PseudoData, self).__init__(yaml_file, None, replace_clean_label)
+        self.yaml_file = yaml_file
+        self._transforms = transforms
+        self.max_query_len = max_query_len
+        self.prepare = ConvertCocoPolysToMask(return_masks=return_masks,
+                                              return_tokens=return_tokens,
+                                              tokenizer=tokenizer,
+                                              max_query_len=max_query_len)
+        self.diver_box_for_vqa = diver_box_for_vqa
+        if "qa" in self.yaml_file:
+            assert(self.diver_box_for_vqa) # must diver box
+        self.tokenizer = tokenizer
+        self.caption_min_box = caption_min_box
+        self.replace_clean_label = replace_clean_label
+        self.further_screen = further_screen
+        self.pack_random_caption_number = pack_random_caption_number
+        self.caption_format_version = caption_format_version
+
+        self.caption_conf = caption_conf
+        self.caption_nms = caption_nms
+        self.inference_caption = inference_caption
+        self.sample_negative_for_grounding_data = sample_negative_for_grounding_data
+        self.random_pack_prob = random_pack_prob
+        self.no_random_pack_probability = no_random_pack_probability
+        self.safeguard_positive_caption = safeguard_positive_caption
+        self.mlm_obj_for_only_positive = mlm_obj_for_only_positive
+        self.local_debug = local_debug
+        try:
+            self.rank = dist.get_rank()
+        except:
+            self.rank = 0
+
+    def __len__(self):
+        return super(PseudoData, self).__len__()
+
+    @staticmethod
+    def check_for_overlap(range1, range2):
+        if range1[0] > range2[1] or range2[0] > range1[1]:
+            return False
+        return True
+
+    def divert_boxes(self, anno):
+        # first get answer start and end
+        answer_start = len(anno['text']) + 1 # +1 for the space
+        answer_end = len(anno["caption"])
+
+        question = anno["caption"][:answer_start] # get the question
+
+        mask_start = len(question)
+        # add the mask token
+        mask_token = self.tokenizer.mask_token
+        if mask_token is None:
+            mask_token = 'answer'
+        question += mask_token
+        mask_end = len(question)
+
+        # divert the box
+        for i in range(len(anno["bboxes"])):
+            # check over lap
+            for j in range(len(anno["tokens_positive"][i])): 
+                if self.check_for_overlap(anno["tokens_positive"][i][j], [answer_start, answer_end]):
+                    # if overlap, then divert the box to the mask token
+                    anno["tokens_positive"][i][j] = [mask_start, mask_end]
+        
+        anno["caption"] = question
+        return question, anno
+
+    def __getitem__(self, idx):
+        img, anno, _, scale = super(PseudoData, self).__getitem__(idx)
+        if self.inference_caption:
+            caption = None
+            if isinstance(anno, list):
+                caption = anno[0]["caption"]  # inference mode for bing
+                anno = []
+            elif len(anno) == 1:
+                caption = anno["caption"]  # inference mode for googlecc
+                anno = []
+            else:
+                caption = " ".join(anno["captions"])
+                anno = []
+        else:
+            if self.caption_format_version == "v2":
+                anno = self.convert_anno_from_yiling_to_ours(anno)
+            
+            if self.further_screen:
+                conf = self.caption_conf
+                nms_thre = self.caption_nms
+
+                bboxes = torch.as_tensor(anno["bboxes"]).float()
+                scores = torch.as_tensor(anno["scores"])
+                tokens_positive = anno["tokens_positive"]
+
+                keep = scores > conf
+                scores = scores[keep]
+                bboxes = bboxes[keep]
+                tokens_positive = [i for index, i in enumerate(tokens_positive) if keep[index]]
+
+                assert (len(tokens_positive) == len(bboxes) == len(scores))
+
+                if len(bboxes) < self.caption_min_box:  # Retry triggered!
+                    return self[np.random.choice(len(self))]
+
+                if nms_thre > 0:
+                    keep = nms(boxes=bboxes, scores=scores, iou_threshold=nms_thre)
+                    scores = scores[keep]
+                    bboxes = bboxes[keep]
+                    tokens_positive = [tokens_positive[i] for i in keep]
+                    assert (len(tokens_positive) == len(bboxes) == len(scores))
+
+                # Write back
+                anno["bboxes"] = bboxes.tolist()
+                anno["scores"] = scores.tolist()
+                anno["tokens_positive"] = tokens_positive
+
+            boxes = torch.as_tensor(anno["bboxes"])
+
+            if len(boxes) < self.caption_min_box:  # Retry triggered!
+                return self[np.random.choice(len(self))]
+
+            target = BoxList(boxes, (anno["img_w"], anno["img_h"]), mode="xyxy")
+            target = target.clip_to_image(remove_empty=True)
+
+            if self.diver_box_for_vqa:
+                caption, anno = self.divert_boxes(anno=anno) # will change caption and "tokens_positive"
+
+            caption = anno["caption"]
+            
+            greenlight_span_for_masked_lm_objective = [(0, len(caption))]
+
+            new_anno = []
+            areas = target.area()
+            for i in range(len(target)):
+                new_anno_i = {}
+                new_anno_i["area"] = areas[i]
+                new_anno_i["iscrowd"] = 0
+                new_anno_i["image_id"] = idx
+                new_anno_i["category_id"] = 1  # following vg and others
+                new_anno_i["id"] = None
+                new_anno_i['bbox'] = target.bbox[i].numpy().tolist()
+                new_anno_i["tokens_positive"] = anno["tokens_positive"][i]
+                new_anno.append(new_anno_i)
+            anno = new_anno
+
+        annotations = {"image_id": idx, "annotations": anno, "caption": caption}
+        annotations["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
+        img, annotations = self.prepare(img, annotations, box_format="xyxy")
+
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+
+        # add additional property
+        for ann in annotations:
+            target.add_field(ann, annotations[ann])
+        
+        # This is the real image_id
+        image_id = self.get_img_id(idx)
+        # Can insert additional field into target if needed
+       
+        sanity_check_target_after_processing(target)
+        
+        return img, target, idx
+
+    def convert_anno_from_yiling_to_ours(self, anno):
+        flatterned_bboxes = []
+        flatterned_tokens_positive = []
+        flatterned_bboxes_scores = []
+        for i in range(len(anno["bboxes"])):
+            # i is the index for entity
+            for j in range(len(anno["bboxes"][i])):
+                # j is the index for each box
+                flatterned_bboxes.append(anno["bboxes"][i][j])
+                flatterned_tokens_positive.append(
+                    anno["tokens_positive"][i])  # Assume this box corresponds to all the token_spans for this entity
+                flatterned_bboxes_scores.append(anno["scores"][i][j])
+        anno["bboxes"] = flatterned_bboxes
+        anno["tokens_positive"] = flatterned_tokens_positive
+        anno["scores"] = flatterned_bboxes_scores
+        return anno
+
+    def get_raw_image(self, idx):
+        image, *_ = super(PseudoData, self).__getitem__(idx)
+        return image
+
+    def get_img_id(self, idx):
+        line_no = self.get_line_no(idx)
+        if self.label_tsv is not None:
+            row = self.label_tsv.seek(line_no)
+            img_id = row[0]
+            return img_id
diff --git a/maskrcnn_benchmark/data/datasets/refexp.py b/maskrcnn_benchmark/data/datasets/refexp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a63015aff6919f1c2ea97382bc319f92b742f76a
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/refexp.py
@@ -0,0 +1,88 @@
+import copy
+from collections import defaultdict
+from pathlib import Path
+
+import torch
+import torch.utils.data
+
+import maskrcnn_benchmark.utils.dist as dist
+from maskrcnn_benchmark.layers.set_loss import generalized_box_iou
+
+from .modulated_coco import ModulatedDataset
+
+
+class RefExpDataset(ModulatedDataset):
+    pass
+
+
+class RefExpEvaluator(object):
+    def __init__(self, refexp_gt, iou_types, k=(1, 5, 10), thresh_iou=0.5):
+        assert isinstance(k, (list, tuple))
+        refexp_gt = copy.deepcopy(refexp_gt)
+        self.refexp_gt = refexp_gt
+        self.iou_types = iou_types
+        self.img_ids = self.refexp_gt.imgs.keys()
+        self.predictions = {}
+        self.k = k
+        self.thresh_iou = thresh_iou
+
+    def accumulate(self):
+        pass
+
+    def update(self, predictions):
+        self.predictions.update(predictions)
+
+    def synchronize_between_processes(self):
+        all_predictions = dist.all_gather(self.predictions)
+        merged_predictions = {}
+        for p in all_predictions:
+            merged_predictions.update(p)
+        self.predictions = merged_predictions
+
+    def summarize(self):
+        if dist.is_main_process():
+            dataset2score = {
+                "refcoco": {k: 0.0 for k in self.k},
+                "refcoco+": {k: 0.0 for k in self.k},
+                "refcocog": {k: 0.0 for k in self.k},
+            }
+            dataset2count = {"refcoco": 0.0, "refcoco+": 0.0, "refcocog": 0.0}
+            for image_id in self.img_ids:
+                ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id)
+                assert len(ann_ids) == 1
+                img_info = self.refexp_gt.loadImgs(image_id)[0]
+
+                target = self.refexp_gt.loadAnns(ann_ids[0])
+                prediction = self.predictions[image_id]
+                assert prediction is not None
+                sorted_scores_boxes = sorted(
+                    zip(prediction["scores"].tolist(), prediction["boxes"].tolist()), reverse=True
+                )
+                sorted_scores, sorted_boxes = zip(*sorted_scores_boxes)
+                sorted_boxes = torch.cat([torch.as_tensor(x).view(1, 4) for x in sorted_boxes])
+                target_bbox = target[0]["bbox"]
+                converted_bbox = [
+                    target_bbox[0],
+                    target_bbox[1],
+                    target_bbox[2] + target_bbox[0],
+                    target_bbox[3] + target_bbox[1],
+                ]
+                giou = generalized_box_iou(sorted_boxes, torch.as_tensor(converted_bbox).view(-1, 4))
+                for k in self.k:
+                    if max(giou[:k]) >= self.thresh_iou:
+                        dataset2score[img_info["dataset_name"]][k] += 1.0
+                dataset2count[img_info["dataset_name"]] += 1.0
+
+            for key, value in dataset2score.items():
+                for k in self.k:
+                    try:
+                        value[k] /= dataset2count[key]
+                    except:
+                        pass
+            results = {}
+            for key, value in dataset2score.items():
+                results[key] = sorted([v for k, v in value.items()])
+                print(f" Dataset: {key} - Precision @ 1, 5, 10: {results[key]} \n")
+
+            return results
+        return None
diff --git a/maskrcnn_benchmark/data/datasets/tsv.py b/maskrcnn_benchmark/data/datasets/tsv.py
new file mode 100644
index 0000000000000000000000000000000000000000..64b92fb16631aae21bad47c1569b582ea0b6431e
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/tsv.py
@@ -0,0 +1,420 @@
+import os
+import os.path as op
+import json
+# import logging
+import base64
+import yaml
+import errno
+import io
+import math
+from PIL import Image, ImageDraw
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from .box_label_loader import LabelLoader
+
+
+def load_linelist_file(linelist_file):
+    if linelist_file is not None:
+        line_list = []
+        with open(linelist_file, 'r') as fp:
+            for i in fp:
+                line_list.append(int(i.strip()))
+        return line_list
+
+
+def img_from_base64(imagestring):
+    try:
+        img = Image.open(io.BytesIO(base64.b64decode(imagestring)))
+        return img.convert('RGB')
+    except ValueError:
+        return None
+
+
+def load_from_yaml_file(yaml_file):
+    with open(yaml_file, 'r') as fp:
+        return yaml.load(fp, Loader=yaml.CLoader)
+
+
+def find_file_path_in_yaml(fname, root):
+    if fname is not None:
+        if op.isfile(fname):
+            return fname
+        elif op.isfile(op.join(root, fname)):
+            return op.join(root, fname)
+        else:
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
+            )
+
+
+def create_lineidx(filein, idxout):
+    idxout_tmp = idxout + '.tmp'
+    with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
+        fsize = os.fstat(tsvin.fileno()).st_size
+        fpos = 0
+        while fpos != fsize:
+            tsvout.write(str(fpos) + "\n")
+            tsvin.readline()
+            fpos = tsvin.tell()
+    os.rename(idxout_tmp, idxout)
+
+
+def read_to_character(fp, c):
+    result = []
+    while True:
+        s = fp.read(32)
+        assert s != ''
+        if c in s:
+            result.append(s[: s.index(c)])
+            break
+        else:
+            result.append(s)
+    return ''.join(result)
+
+
+class TSVFile(object):
+    def __init__(self, tsv_file, generate_lineidx=False):
+        self.tsv_file = tsv_file
+        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
+        self._fp = None
+        self._lineidx = None
+        # the process always keeps the process which opens the file.
+        # If the pid is not equal to the currrent pid, we will re-open the file.
+        self.pid = None
+        # generate lineidx if not exist
+        if not op.isfile(self.lineidx) and generate_lineidx:
+            create_lineidx(self.tsv_file, self.lineidx)
+
+    def __del__(self):
+        if self._fp:
+            self._fp.close()
+
+    def __str__(self):
+        return "TSVFile(tsv_file='{}')".format(self.tsv_file)
+
+    def __repr__(self):
+        return str(self)
+
+    def num_rows(self):
+        self._ensure_lineidx_loaded()
+        return len(self._lineidx)
+
+    def seek(self, idx):
+        self._ensure_tsv_opened()
+        self._ensure_lineidx_loaded()
+        try:
+            pos = self._lineidx[idx]
+        except:
+            # logging.info('{}-{}'.format(self.tsv_file, idx))
+            raise
+        self._fp.seek(pos)
+        return [s.strip() for s in self._fp.readline().split('\t')]
+
+    def seek_first_column(self, idx):
+        self._ensure_tsv_opened()
+        self._ensure_lineidx_loaded()
+        pos = self._lineidx[idx]
+        self._fp.seek(pos)
+        return read_to_character(self._fp, '\t')
+
+    def get_key(self, idx):
+        return self.seek_first_column(idx)
+
+    def __getitem__(self, index):
+        return self.seek(index)
+
+    def __len__(self):
+        return self.num_rows()
+
+    def _ensure_lineidx_loaded(self):
+        if self._lineidx is None:
+            # logging.info('loading lineidx: {}'.format(self.lineidx))
+            with open(self.lineidx, 'r') as fp:
+                self._lineidx = [int(i.strip()) for i in fp.readlines()]
+
+    def _ensure_tsv_opened(self):
+        if self._fp is None:
+            self._fp = open(self.tsv_file, 'r')
+            self.pid = os.getpid()
+
+        if self.pid != os.getpid():
+            # logging.info('re-open {} because the process id changed'.format(self.tsv_file))
+            self._fp = open(self.tsv_file, 'r')
+            self.pid = os.getpid()
+
+
+class CompositeTSVFile():
+    def __init__(self, file_list, seq_file, root='.'):
+        if isinstance(file_list, str):
+            self.file_list = load_list_file(file_list)
+        else:
+            assert isinstance(file_list, list)
+            self.file_list = file_list
+
+        self.seq_file = seq_file
+        self.root = root
+        self.initialized = False
+        self.initialize()
+
+    def get_key(self, index):
+        idx_source, idx_row = self.seq[index]
+        k = self.tsvs[idx_source].get_key(idx_row)
+        return '_'.join([self.file_list[idx_source], k])
+
+    def num_rows(self):
+        return len(self.seq)
+
+    def __getitem__(self, index):
+        idx_source, idx_row = self.seq[index]
+        return self.tsvs[idx_source].seek(idx_row)
+
+    def __len__(self):
+        return len(self.seq)
+
+    def initialize(self):
+        '''
+        this function has to be called in init function if cache_policy is
+        enabled. Thus, let's always call it in init funciton to make it simple.
+        '''
+        if self.initialized:
+            return
+        self.seq = []
+        with open(self.seq_file, 'r') as fp:
+            for line in fp:
+                parts = line.strip().split('\t')
+                self.seq.append([int(parts[0]), int(parts[1])])
+        self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
+        self.initialized = True
+
+
+def load_list_file(fname):
+    with open(fname, 'r') as fp:
+        lines = fp.readlines()
+    result = [line.strip() for line in lines]
+    if len(result) > 0 and result[-1] == '':
+        result = result[:-1]
+    return result
+
+
+class TSVDataset(object):
+    def __init__(self, img_file, label_file=None, hw_file=None,
+                 linelist_file=None, imageid2idx_file=None):
+        """Constructor.
+        Args:
+            img_file: Image file with image key and base64 encoded image str.
+            label_file: An optional label file with image key and label information.
+                A label_file is required for training and optional for testing.
+            hw_file: An optional file with image key and image height/width info.
+            linelist_file: An optional file with a list of line indexes to load samples.
+                It is useful to select a subset of samples or duplicate samples.
+        """
+        self.img_file = img_file
+        self.label_file = label_file
+        self.hw_file = hw_file
+        self.linelist_file = linelist_file
+
+        self.img_tsv = TSVFile(img_file)
+        self.label_tsv = None if label_file is None else TSVFile(label_file, generate_lineidx=True)
+        self.hw_tsv = None if hw_file is None else TSVFile(hw_file)
+        self.line_list = load_linelist_file(linelist_file)
+        self.imageid2idx = None
+        if imageid2idx_file is not None:
+            self.imageid2idx = json.load(open(imageid2idx_file, 'r'))
+
+        self.transforms = None
+
+    def __len__(self):
+        if self.line_list is None:
+            if self.imageid2idx is not None:
+                assert self.label_tsv is not None, "label_tsv is None!!!"
+                return self.label_tsv.num_rows()
+            return self.img_tsv.num_rows()
+        else:
+            return len(self.line_list)
+
+    def __getitem__(self, idx):
+        img = self.get_image(idx)
+        img_size = img.size  # w, h
+        annotations = self.get_annotations(idx)
+        # print(idx, annotations)
+        target = self.get_target_from_annotations(annotations, img_size, idx)
+        img, target = self.apply_transforms(img, target)
+
+        if self.transforms is None:
+            return img, target, idx, 1.0
+        else:
+            new_img_size = img.shape[1:]
+            scale = math.sqrt(float(new_img_size[0] * new_img_size[1]) / float(img_size[0] * img_size[1]))
+            return img, target, idx, scale
+
+    def get_line_no(self, idx):
+        return idx if self.line_list is None else self.line_list[idx]
+
+    def get_image(self, idx):
+        line_no = self.get_line_no(idx)
+        if self.imageid2idx is not None:
+            assert self.label_tsv is not None, "label_tsv is None!!!"
+            row = self.label_tsv.seek(line_no)
+            annotations = json.loads(row[1])
+            imageid = annotations["img_id"]
+            line_no = self.imageid2idx[imageid]
+        row = self.img_tsv.seek(line_no)
+        # use -1 to support old format with multiple columns.
+        img = img_from_base64(row[-1])
+        return img
+
+    def get_annotations(self, idx):
+        line_no = self.get_line_no(idx)
+        if self.label_tsv is not None:
+            row = self.label_tsv.seek(line_no)
+            annotations = json.loads(row[1])
+            return annotations
+        else:
+            return []
+
+    def get_target_from_annotations(self, annotations, img_size, idx):
+        # This function will be overwritten by each dataset to
+        # decode the labels to specific formats for each task.
+        return annotations
+
+    def apply_transforms(self, image, target=None):
+        # This function will be overwritten by each dataset to
+        # apply transforms to image and targets.
+        return image, target
+
+    def get_img_info(self, idx):
+        if self.imageid2idx is not None:
+            assert self.label_tsv is not None, "label_tsv is None!!!"
+            line_no = self.get_line_no(idx)
+            row = self.label_tsv.seek(line_no)
+            annotations = json.loads(row[1])
+            return {"height": int(annotations["img_w"]), "width": int(annotations["img_w"])}
+
+        if self.hw_tsv is not None:
+            line_no = self.get_line_no(idx)
+            row = self.hw_tsv.seek(line_no)
+            try:
+                # json string format with "height" and "width" being the keys
+                data = json.loads(row[1])
+                if type(data) == list:
+                    return data[0]
+                elif type(data) == dict:
+                    return data
+            except ValueError:
+                # list of strings representing height and width in order
+                hw_str = row[1].split(' ')
+                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
+                return hw_dict
+
+    def get_img_key(self, idx):
+        line_no = self.get_line_no(idx)
+        # based on the overhead of reading each row.
+        if self.imageid2idx is not None:
+            assert self.label_tsv is not None, "label_tsv is None!!!"
+            row = self.label_tsv.seek(line_no)
+            annotations = json.loads(row[1])
+            return annotations["img_id"]
+
+        if self.hw_tsv:
+            return self.hw_tsv.seek(line_no)[0]
+        elif self.label_tsv:
+            return self.label_tsv.seek(line_no)[0]
+        else:
+            return self.img_tsv.seek(line_no)[0]
+
+
+class TSVYamlDataset(TSVDataset):
+    """ TSVDataset taking a Yaml file for easy function call
+    """
+
+    def __init__(self, yaml_file, root=None, replace_clean_label=False):
+        print("Reading {}".format(yaml_file))
+        self.cfg = load_from_yaml_file(yaml_file)
+        if root:
+            self.root = root
+        else:
+            self.root = op.dirname(yaml_file)
+        img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
+        label_file = find_file_path_in_yaml(self.cfg.get('label', None),
+                                            self.root)
+        hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
+        linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
+                                               self.root)
+        imageid2idx_file = find_file_path_in_yaml(self.cfg.get('imageid2idx', None),
+                                               self.root)
+
+        if replace_clean_label:
+            assert ("raw_label" in label_file)
+            label_file = label_file.replace("raw_label", "clean_label")
+
+        super(TSVYamlDataset, self).__init__(
+            img_file, label_file, hw_file, linelist_file, imageid2idx_file)
+
+
+class ODTSVDataset(TSVYamlDataset):
+    """
+    Generic TSV dataset format for Object Detection.
+    """
+
+    def __init__(self, yaml_file, extra_fields=(), transforms=None,
+                 is_load_label=True, **kwargs):
+        if yaml_file is None:
+            return
+        super(ODTSVDataset, self).__init__(yaml_file)
+
+        self.transforms = transforms
+        self.is_load_label = is_load_label
+        self.attribute_on = False
+        # self.attribute_on = kwargs['args'].MODEL.ATTRIBUTE_ON if "args" in kwargs else False
+
+        if self.is_load_label:
+            # construct maps
+            jsondict_file = find_file_path_in_yaml(
+                self.cfg.get("labelmap", None), self.root
+            )
+            if jsondict_file is None:
+                jsondict_file = find_file_path_in_yaml(
+                    self.cfg.get("jsondict", None), self.root
+                )
+            if "json" in jsondict_file:
+                jsondict = json.load(open(jsondict_file, 'r'))
+                if "label_to_idx" not in jsondict:
+                    jsondict = {'label_to_idx': jsondict}
+            elif "tsv" in jsondict_file:
+                label_to_idx = {}
+                counter = 1
+                with open(jsondict_file) as f:
+                    for line in f:
+                        label_to_idx[line.strip()] = counter
+                        counter += 1
+                jsondict = {'label_to_idx': label_to_idx}
+            else:
+                assert (0)
+
+            self.labelmap = {}
+            self.class_to_ind = jsondict['label_to_idx']
+            self.class_to_ind['__background__'] = 0
+            self.ind_to_class = {v: k for k, v in self.class_to_ind.items()}
+            self.labelmap['class_to_ind'] = self.class_to_ind
+
+            if self.attribute_on:
+                self.attribute_to_ind = jsondict['attribute_to_idx']
+                self.attribute_to_ind['__no_attribute__'] = 0
+                self.ind_to_attribute = {v: k for k, v in self.attribute_to_ind.items()}
+                self.labelmap['attribute_to_ind'] = self.attribute_to_ind
+
+            self.label_loader = LabelLoader(
+                labelmap=self.labelmap,
+                extra_fields=extra_fields,
+            )
+
+    def get_target_from_annotations(self, annotations, img_size, idx):
+        if isinstance(annotations, list):
+            annotations = {"objects": annotations}
+        if self.is_load_label:
+            return self.label_loader(annotations['objects'], img_size)
+
+    def apply_transforms(self, img, target=None):
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+        return img, target
diff --git a/maskrcnn_benchmark/data/datasets/vg.py b/maskrcnn_benchmark/data/datasets/vg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c94eacd3ee75346ba06a61efdb0f28ae53b82501
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/vg.py
@@ -0,0 +1,267 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import collections
+import json
+import os.path as op
+
+import numpy as np
+import torch
+
+from .tsv import TSVYamlDataset, find_file_path_in_yaml
+from .box_label_loader import BoxLabelLoader
+from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV
+
+
+class VGDetectionTSV(CocoDetectionTSV):
+    pass
+
+
+def sort_key_by_val(dic):
+    sorted_dic = sorted(dic.items(), key=lambda kv: kv[1])
+    return [kv[0] for kv in sorted_dic]
+
+
+def bbox_overlaps(anchors, gt_boxes):
+    """
+    anchors: (N, 4) ndarray of float
+    gt_boxes: (K, 4) ndarray of float
+    overlaps: (N, K) ndarray of overlap between boxes and query_boxes
+    """
+    N = anchors.size(0)
+    K = gt_boxes.size(0)
+
+    gt_boxes_area = ((gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
+                     (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)).view(1, K)
+
+    anchors_area = ((anchors[:, 2] - anchors[:, 0] + 1) *
+                    (anchors[:, 3] - anchors[:, 1] + 1)).view(N, 1)
+
+    boxes = anchors.view(N, 1, 4).expand(N, K, 4)
+    query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)
+
+    iw = (torch.min(boxes[:, :, 2], query_boxes[:, :, 2]) -
+          torch.max(boxes[:, :, 0], query_boxes[:, :, 0]) + 1)
+    iw[iw < 0] = 0
+
+    ih = (torch.min(boxes[:, :, 3], query_boxes[:, :, 3]) -
+          torch.max(boxes[:, :, 1], query_boxes[:, :, 1]) + 1)
+    ih[ih < 0] = 0
+
+    ua = anchors_area + gt_boxes_area - (iw * ih)
+    overlaps = iw * ih / ua
+
+    return overlaps
+
+
+# VG data loader for Danfei Xu's Scene graph focused format.
+# todo: if ordering of classes, attributes, relations changed
+# todo make sure to re-write the obj_classes.txt/rel_classes.txt files
+
+def _box_filter(boxes, must_overlap=False):
+    """ Only include boxes that overlap as possible relations.
+    If no overlapping boxes, use all of them."""
+    overlaps = bbox_overlaps(boxes, boxes).numpy() > 0
+    np.fill_diagonal(overlaps, 0)
+
+    all_possib = np.ones_like(overlaps, dtype=np.bool)
+    np.fill_diagonal(all_possib, 0)
+
+    if must_overlap:
+        possible_boxes = np.column_stack(np.where(overlaps))
+
+        if possible_boxes.size == 0:
+            possible_boxes = np.column_stack(np.where(all_possib))
+    else:
+        possible_boxes = np.column_stack(np.where(all_possib))
+    return possible_boxes
+
+
+class VGTSVDataset(TSVYamlDataset):
+    """
+    Generic TSV dataset format for Object Detection.
+    """
+
+    def __init__(self, yaml_file, extra_fields=None, transforms=None,
+                 is_load_label=True, filter_duplicate_rels=True,
+                 relation_on=False, cv2_output=False, **kwargs):
+        if extra_fields is None:
+            extra_fields = []
+        self.transforms = transforms
+        self.is_load_label = is_load_label
+        self.relation_on = relation_on
+        super(VGTSVDataset, self).__init__(yaml_file, cv2_output=cv2_output)
+
+        ignore_attrs = self.cfg.get("ignore_attrs", None)
+        # construct those maps
+        jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
+        jsondict = json.load(open(jsondict_file, 'r'))
+
+        # self.linelist_file
+        if 'train' in op.basename(self.linelist_file):
+            self.split = "train"
+        elif 'test' in op.basename(self.linelist_file) \
+                or 'val' in op.basename(self.linelist_file) \
+                or 'valid' in op.basename(self.linelist_file):
+            self.split = "test"
+        else:
+            raise ValueError("Split must be one of [train, test], but get {}!".format(self.linelist_file))
+        self.filter_duplicate_rels = filter_duplicate_rels and self.split == 'train'
+
+        self.class_to_ind = jsondict['label_to_idx']
+        self.ind_to_class = jsondict['idx_to_label']
+        self.class_to_ind['__background__'] = 0
+        self.ind_to_class['0'] = '__background__'
+        self.classes = sort_key_by_val(self.class_to_ind)
+        assert (all([self.classes[i] == self.ind_to_class[str(i)] for i in range(len(self.classes))]))
+
+        # writing obj classes to disk for Neural Motif model building.
+        obj_classes_out_fn = op.splitext(self.label_file)[0] + ".obj_classes.txt"
+        if not op.isfile(obj_classes_out_fn):
+            with open(obj_classes_out_fn, 'w') as f:
+                for item in self.classes:
+                    f.write("%s\n" % item)
+
+        self.attribute_to_ind = jsondict['attribute_to_idx']
+        self.ind_to_attribute = jsondict['idx_to_attribute']
+        self.attribute_to_ind['__no_attribute__'] = 0
+        self.ind_to_attribute['0'] = '__no_attribute__'
+        self.attributes = sort_key_by_val(self.attribute_to_ind)
+        assert (all([self.attributes[i] == self.ind_to_attribute[str(i)] for i in range(len(self.attributes))]))
+
+        self.relation_to_ind = jsondict['predicate_to_idx']
+        self.ind_to_relation = jsondict['idx_to_predicate']
+        self.relation_to_ind['__no_relation__'] = 0
+        self.ind_to_relation['0'] = '__no_relation__'
+        self.relations = sort_key_by_val(self.relation_to_ind)
+        assert (all([self.relations[i] == self.ind_to_relation[str(i)] for i in range(len(self.relations))]))
+
+        # writing rel classes to disk for Neural Motif Model building.
+        rel_classes_out_fn = op.splitext(self.label_file)[0] + '.rel_classes.txt'
+        if not op.isfile(rel_classes_out_fn):
+            with open(rel_classes_out_fn, 'w') as f:
+                for item in self.relations:
+                    f.write("%s\n" % item)
+
+        # label map: minus one because we will add one in BoxLabelLoader
+        self.labelmap = {key: val - 1 for key, val in self.class_to_ind.items()}
+        labelmap_file = find_file_path_in_yaml(self.cfg.get("labelmap_dec"), self.root)
+        # self.labelmap_dec = load_labelmap_file(labelmap_file)
+        if self.is_load_label:
+            self.label_loader = BoxLabelLoader(
+                labelmap=self.labelmap,
+                extra_fields=extra_fields,
+                ignore_attrs=ignore_attrs
+            )
+
+        # get frequency prior for relations
+        if self.relation_on:
+            self.freq_prior_file = op.splitext(self.label_file)[0] + ".freq_prior.npy"
+            if self.split == 'train' and not op.exists(self.freq_prior_file):
+                print("Computing frequency prior matrix...")
+                fg_matrix, bg_matrix = self._get_freq_prior()
+                prob_matrix = fg_matrix.astype(np.float32)
+                prob_matrix[:, :, 0] = bg_matrix
+                prob_matrix[:, :, 0] += 1
+                prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]
+                np.save(self.freq_prior_file, prob_matrix)
+
+    def _get_freq_prior(self, must_overlap=False):
+        fg_matrix = np.zeros((
+            len(self.classes),
+            len(self.classes),
+            len(self.relations)
+        ), dtype=np.int64)
+
+        bg_matrix = np.zeros((
+            len(self.classes),
+            len(self.classes),
+        ), dtype=np.int64)
+
+        for ex_ind in range(self.__len__()):
+            target = self.get_groundtruth(ex_ind)
+            gt_classes = target.get_field('labels').numpy()
+            gt_relations = target.get_field('relation_labels').numpy()
+            gt_boxes = target.bbox
+
+            # For the foreground, we'll just look at everything
+            try:
+                o1o2 = gt_classes[gt_relations[:, :2]]
+                for (o1, o2), gtr in zip(o1o2, gt_relations[:, 2]):
+                    fg_matrix[o1, o2, gtr] += 1
+
+                # For the background, get all of the things that overlap.
+                o1o2_total = gt_classes[np.array(
+                    _box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)]
+                for (o1, o2) in o1o2_total:
+                    bg_matrix[o1, o2] += 1
+            except IndexError as e:
+                assert len(gt_relations) == 0
+
+            if ex_ind % 20 == 0:
+                print("processing {}/{}".format(ex_ind, self.__len__()))
+
+        return fg_matrix, bg_matrix
+
+    def relation_loader(self, relation_triplets, target):
+        # relation_triplets [list of tuples]: M*3
+        # target: BoxList from label_loader
+        if self.filter_duplicate_rels:
+            # Filter out dupes!
+            assert self.split == 'train'
+            all_rel_sets = collections.defaultdict(list)
+            for (o0, o1, r) in relation_triplets:
+                all_rel_sets[(o0, o1)].append(r)
+            relation_triplets = [(k[0], k[1], np.random.choice(v)) for k, v in all_rel_sets.items()]
+
+        # get M*M pred_labels
+        relations = torch.zeros([len(target), len(target)], dtype=torch.int64)
+        for i in range(len(relation_triplets)):
+            subj_id = relation_triplets[i][0]
+            obj_id = relation_triplets[i][1]
+            pred = relation_triplets[i][2]
+            relations[subj_id, obj_id] = int(pred)
+
+        relation_triplets = torch.tensor(relation_triplets)
+        target.add_field("relation_labels", relation_triplets)
+        target.add_field("pred_labels", relations)
+        return target
+
+    def get_target_from_annotations(self, annotations, img_size, idx):
+        if self.is_load_label and annotations:
+            target = self.label_loader(annotations['objects'], img_size)
+            # make sure no boxes are removed
+            assert (len(annotations['objects']) == len(target))
+            if self.split in ["val", "test"]:
+                # add the difficult field
+                target.add_field("difficult", torch.zeros(len(target), dtype=torch.int32))
+            # load relations
+            if self.relation_on:
+                target = self.relation_loader(annotations["relations"], target)
+            return target
+
+    def get_groundtruth(self, idx, call=False):
+        # similar to __getitem__ but without transform
+        img = self.get_image(idx)
+        if self.cv2_output:
+            img_size = img.shape[:2][::-1]  # h, w -> w, h
+        else:
+            img_size = img.size  # w, h
+        annotations = self.get_annotations(idx)
+        target = self.get_target_from_annotations(annotations, img_size, idx)
+        if call:
+            return img, target, annotations
+        else:
+            return target
+
+    def apply_transforms(self, img, target=None):
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+        return img, target
+
+    def map_class_id_to_class_name(self, class_id):
+        return self.classes[class_id]
+
+    def map_attribute_id_to_attribute_name(self, attribute_id):
+        return self.attributes[attribute_id]
+
+    def map_relation_id_to_relation_name(self, relation_id):
+        return self.relations[relation_id]
diff --git a/maskrcnn_benchmark/data/datasets/voc.py b/maskrcnn_benchmark/data/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4288ba5231dc513ec1e33fe952527ba6433fd199
--- /dev/null
+++ b/maskrcnn_benchmark/data/datasets/voc.py
@@ -0,0 +1,134 @@
+import os
+
+import torch
+import torch.utils.data
+from PIL import Image
+import sys
+
+if sys.version_info[0] == 2:
+    import xml.etree.cElementTree as ET
+else:
+    import xml.etree.ElementTree as ET
+
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+
+class PascalVOCDataset(torch.utils.data.Dataset):
+
+    CLASSES = (
+        "__background__ ",
+        "aeroplane",
+        "bicycle",
+        "bird",
+        "boat",
+        "bottle",
+        "bus",
+        "car",
+        "cat",
+        "chair",
+        "cow",
+        "diningtable",
+        "dog",
+        "horse",
+        "motorbike",
+        "person",
+        "pottedplant",
+        "sheep",
+        "sofa",
+        "train",
+        "tvmonitor",
+    )
+
+    def __init__(self, data_dir, split, use_difficult=False, transforms=None):
+        self.root = data_dir
+        self.image_set = split
+        self.keep_difficult = use_difficult
+        self.transforms = transforms
+
+        self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
+        self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
+        self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")
+
+        with open(self._imgsetpath % self.image_set) as f:
+            self.ids = f.readlines()
+        self.ids = [x.strip("\n") for x in self.ids]
+        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
+
+        cls = PascalVOCDataset.CLASSES
+        self.class_to_ind = dict(zip(cls, range(len(cls))))
+
+    def __getitem__(self, index):
+        img_id = self.ids[index]
+        img = Image.open(self._imgpath % img_id).convert("RGB")
+
+        target = self.get_groundtruth(index)
+        target = target.clip_to_image(remove_empty=True)
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target, index
+
+    def __len__(self):
+        return len(self.ids)
+
+    def get_groundtruth(self, index):
+        img_id = self.ids[index]
+        anno = ET.parse(self._annopath % img_id).getroot()
+        anno = self._preprocess_annotation(anno)
+
+        height, width = anno["im_info"]
+        target = BoxList(anno["boxes"], (width, height), mode="xyxy")
+        target.add_field("labels", anno["labels"])
+        target.add_field("difficult", anno["difficult"])
+        return target
+
+    def _preprocess_annotation(self, target):
+        boxes = []
+        gt_classes = []
+        difficult_boxes = []
+        TO_REMOVE = 1
+        
+        for obj in target.iter("object"):
+            difficult = int(obj.find("difficult").text) == 1
+            if not self.keep_difficult and difficult:
+                continue
+            name = obj.find("name").text.lower().strip()
+            bb = obj.find("bndbox")
+            # Make pixel indexes 0-based
+            # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211"
+            box = [
+                bb.find("xmin").text, 
+                bb.find("ymin").text, 
+                bb.find("xmax").text, 
+                bb.find("ymax").text,
+            ]
+            bndbox = tuple(
+                map(lambda x: x - TO_REMOVE, list(map(int, box)))
+            )
+
+            boxes.append(bndbox)
+            gt_classes.append(self.class_to_ind[name])
+            difficult_boxes.append(difficult)
+
+        size = target.find("size")
+        im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))
+
+        res = {
+            "boxes": torch.tensor(boxes, dtype=torch.float32),
+            "labels": torch.tensor(gt_classes),
+            "difficult": torch.tensor(difficult_boxes),
+            "im_info": im_info,
+        }
+        return res
+
+    def get_img_info(self, index):
+        img_id = self.ids[index]
+        anno = ET.parse(self._annopath % img_id).getroot()
+        size = anno.find("size")
+        im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))
+        return {"height": im_info[0], "width": im_info[1]}
+
+    def map_class_id_to_class_name(self, class_id):
+        return PascalVOCDataset.CLASSES[class_id]
diff --git a/maskrcnn_benchmark/data/samplers/__init__.py b/maskrcnn_benchmark/data/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f891498f3d66c08a4840de0b12fb03b6834ba4c8
--- /dev/null
+++ b/maskrcnn_benchmark/data/samplers/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from .distributed import DistributedSampler
+from .grouped_batch_sampler import GroupedBatchSampler
+from .iteration_based_batch_sampler import IterationBasedBatchSampler
+
+__all__ = ["DistributedSampler", "GroupedBatchSampler", "IterationBasedBatchSampler"]
diff --git a/maskrcnn_benchmark/data/samplers/distributed.py b/maskrcnn_benchmark/data/samplers/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b2aa926f61243e77a9e959ef36826c854467fc5
--- /dev/null
+++ b/maskrcnn_benchmark/data/samplers/distributed.py
@@ -0,0 +1,72 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# Code is copy-pasted exactly as in torch.utils.data.distributed.
+# FIXME remove this once c10d fixes the bug it has
+import math
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+
+from maskrcnn_benchmark.utils.comm import shared_random_seed
+
+
+class DistributedSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset.
+    It is especially useful in conjunction with
+    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+    process can pass a DistributedSampler instance as a DataLoader sampler,
+    and load a subset of the original dataset that is exclusive to it.
+    .. note::
+        Dataset is assumed to be of constant size.
+    Arguments:
+        dataset: Dataset used for sampling.
+        num_replicas (optional): Number of processes participating in
+            distributed training.
+        rank (optional): Rank of the current process within num_replicas.
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, use_random=False):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+        self.use_random = use_random
+
+    def __iter__(self):
+        if self.shuffle:
+            # deterministically shuffle based on epoch
+            _seed = self.epoch
+            if self.use_random:
+                _seed = int(shared_random_seed())
+            g = torch.Generator()
+            g.manual_seed(_seed)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        offset = self.num_samples * self.rank
+        indices = indices[offset : offset + self.num_samples]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
diff --git a/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py b/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f6985b9ccef6d9a7353e11817a904d309395b82
--- /dev/null
+++ b/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import itertools
+
+import torch
+from torch.utils.data.sampler import BatchSampler
+from torch.utils.data.sampler import Sampler
+
+
+class GroupedBatchSampler(BatchSampler):
+    """
+    Wraps another sampler to yield a mini-batch of indices.
+    It enforces that elements from the same group should appear in groups of batch_size.
+    It also tries to provide mini-batches which follows an ordering which is
+    as close as possible to the ordering from the original sampler.
+
+    Arguments:
+        sampler (Sampler): Base sampler.
+        batch_size (int): Size of mini-batch.
+        drop_uneven (bool): If ``True``, the sampler will drop the batches whose
+            size is less than ``batch_size``
+
+    """
+
+    def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
+        if not isinstance(sampler, Sampler):
+            raise ValueError(
+                "sampler should be an instance of "
+                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
+            )
+        self.sampler = sampler
+        self.group_ids = torch.as_tensor(group_ids)
+        assert self.group_ids.dim() == 1
+        self.batch_size = batch_size
+        self.drop_uneven = drop_uneven
+
+        self.groups = torch.unique(self.group_ids).sort(0)[0]
+
+        self._can_reuse_batches = False
+
+    def _prepare_batches(self):
+        dataset_size = len(self.group_ids)
+        # get the sampled indices from the sampler
+        sampled_ids = torch.as_tensor(list(self.sampler))
+        # potentially not all elements of the dataset were sampled
+        # by the sampler (e.g., DistributedSampler).
+        # construct a tensor which contains -1 if the element was
+        # not sampled, and a non-negative number indicating the
+        # order where the element was sampled.
+        # for example. if sampled_ids = [3, 1] and dataset_size = 5,
+        # the order is [-1, 1, -1, 0, -1]
+        order = torch.full((dataset_size,), -1, dtype=torch.int64)
+        order[sampled_ids] = torch.arange(len(sampled_ids))
+
+        # get a mask with the elements that were sampled
+        mask = order >= 0
+
+        # find the elements that belong to each individual cluster
+        clusters = [(self.group_ids == i) & mask for i in self.groups]
+        # get relative order of the elements inside each cluster
+        # that follows the order from the sampler
+        relative_order = [order[cluster] for cluster in clusters]
+        # with the relative order, find the absolute order in the
+        # sampled space
+        permutation_ids = [s[s.sort()[1]] for s in relative_order]
+        # permute each cluster so that they follow the order from
+        # the sampler
+        permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]
+
+        # splits each cluster in batch_size, and merge as a list of tensors
+        splits = [c.split(self.batch_size) for c in permuted_clusters]
+        merged = tuple(itertools.chain.from_iterable(splits))
+
+        # now each batch internally has the right order, but
+        # they are grouped by clusters. Find the permutation between
+        # different batches that brings them as close as possible to
+        # the order that we have in the sampler. For that, we will consider the
+        # ordering as coming from the first element of each batch, and sort
+        # correspondingly
+        first_element_of_batch = [t[0].item() for t in merged]
+        # get and inverse mapping from sampled indices and the position where
+        # they occur (as returned by the sampler)
+        inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
+        # from the first element in each batch, get a relative ordering
+        first_index_of_batch = torch.as_tensor(
+            [inv_sampled_ids_map[s] for s in first_element_of_batch]
+        )
+
+        # permute the batches so that they approximately follow the order
+        # from the sampler
+        permutation_order = first_index_of_batch.sort(0)[1].tolist()
+        # finally, permute the batches
+        batches = [merged[i].tolist() for i in permutation_order]
+
+        if self.drop_uneven:
+            kept = []
+            for batch in batches:
+                if len(batch) == self.batch_size:
+                    kept.append(batch)
+            batches = kept
+        return batches
+
+    def __iter__(self):
+        if self._can_reuse_batches:
+            batches = self._batches
+            self._can_reuse_batches = False
+        else:
+            batches = self._prepare_batches()
+        self._batches = batches
+        return iter(batches)
+
+    def __len__(self):
+        if not hasattr(self, "_batches"):
+            self._batches = self._prepare_batches()
+            self._can_reuse_batches = True
+        return len(self._batches)
diff --git a/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py b/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..431693eecd2e474dacdbc9eb805dbe2b092234cc
--- /dev/null
+++ b/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py
@@ -0,0 +1,31 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from torch.utils.data.sampler import BatchSampler
+
+
+class IterationBasedBatchSampler(BatchSampler):
+    """
+    Wraps a BatchSampler, resampling from it until
+    a specified number of iterations have been sampled
+    """
+
+    def __init__(self, batch_sampler, num_iterations, start_iter=0):
+        self.batch_sampler = batch_sampler
+        self.num_iterations = num_iterations
+        self.start_iter = start_iter
+
+    def __iter__(self):
+        iteration = self.start_iter
+        while iteration <= self.num_iterations:
+            # if the underlying sampler has a set_epoch method, like
+            # DistributedSampler, used for making each process see
+            # a different split of the dataset, then set it
+            if hasattr(self.batch_sampler.sampler, "set_epoch"):
+                self.batch_sampler.sampler.set_epoch(iteration)
+            for batch in self.batch_sampler:
+                iteration += 1
+                if iteration > self.num_iterations:
+                    break
+                yield batch
+
+    def __len__(self):
+        return self.num_iterations
diff --git a/maskrcnn_benchmark/data/transforms/__init__.py b/maskrcnn_benchmark/data/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94ce850056fdd7ed45f416bc4ead90f3f7da0073
--- /dev/null
+++ b/maskrcnn_benchmark/data/transforms/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from .transforms import Compose
+from .transforms import Resize
+from .transforms import RandomHorizontalFlip
+from .transforms import ToTensor
+from .transforms import Normalize
+
+from .build import build_transforms
diff --git a/maskrcnn_benchmark/data/transforms/build.py b/maskrcnn_benchmark/data/transforms/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f66c092e4ca0229b7bd5607c84e7be9ce52eb9f
--- /dev/null
+++ b/maskrcnn_benchmark/data/transforms/build.py
@@ -0,0 +1,45 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from . import transforms as T
+
+
+def build_transforms(cfg, is_train=True):
+    if is_train:
+        if len(cfg.AUGMENT.MULT_MIN_SIZE_TRAIN)>0:
+            min_size = cfg.AUGMENT.MULT_MIN_SIZE_TRAIN
+        else:
+            min_size = cfg.INPUT.MIN_SIZE_TRAIN
+        max_size = cfg.INPUT.MAX_SIZE_TRAIN
+        flip_horizontal_prob = cfg.AUGMENT.FLIP_PROB_TRAIN
+        flip_vertical_prob = cfg.AUGMENT.VERTICAL_FLIP_PROB_TRAIN
+        brightness = cfg.AUGMENT.BRIGHTNESS
+        contrast = cfg.AUGMENT.CONTRAST
+        saturation = cfg.AUGMENT.SATURATION
+        hue = cfg.AUGMENT.HUE
+
+        crop_prob = cfg.AUGMENT.CROP_PROB
+        min_ious = cfg.AUGMENT.CROP_MIN_IOUS
+        min_crop_size = cfg.AUGMENT.CROP_MIN_SIZE
+
+    else:
+        min_size = cfg.INPUT.MIN_SIZE_TEST
+        max_size = cfg.INPUT.MAX_SIZE_TEST
+        flip_horizontal_prob = 0.0
+
+    fix_res = cfg.INPUT.FIX_RES
+    if cfg.INPUT.FORMAT is not '':
+        input_format = cfg.INPUT.FORMAT
+    elif cfg.INPUT.TO_BGR255:
+        input_format = 'bgr255'
+    normalize_transform = T.Normalize(
+        mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, format=input_format
+    )
+ 
+    transform = T.Compose(
+        [
+            T.Resize(min_size, max_size, restrict=fix_res),
+            T.RandomHorizontalFlip(flip_horizontal_prob),
+            T.ToTensor(),
+            normalize_transform,
+        ]
+    )
+    return transform
diff --git a/maskrcnn_benchmark/data/transforms/transforms.py b/maskrcnn_benchmark/data/transforms/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..3698aec52d9df8bcd9bb73cf2a80294c917898bb
--- /dev/null
+++ b/maskrcnn_benchmark/data/transforms/transforms.py
@@ -0,0 +1,385 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import cv2
+import random
+import numpy as np
+import math
+import torch
+import torchvision
+from torchvision.transforms import functional as F
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+def matrix_iou(a, b, relative=False):
+    """
+    return iou of a and b, numpy version for data augenmentation
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+    if relative:
+        ious = area_i / (area_b[:, np.newaxis]+1e-12)
+    else:
+        ious = area_i / (area_a[:, np.newaxis] + area_b - area_i+1e-12)
+    return ious
+
+
+class RACompose(object):
+    def __init__(self, pre_transforms, rand_transforms, post_transforms, concurrent=2):
+        self.preprocess = pre_transforms
+        self.transforms = post_transforms
+        self.rand_transforms = rand_transforms
+        self.concurrent = concurrent
+
+    def __call__(self, image, target):
+        for t in self.preprocess:
+            image, target = t(image, target)
+        for t in random.choices(self.rand_transforms, k=self.concurrent):
+            image = np.array(image)
+            image, target = t(image, target)
+        for t in self.transforms:
+            image, target = t(image, target)
+
+        return image, target
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + "("
+        for t in self.preprocess:
+            format_string += "\n"
+            format_string += "    {0}".format(t)
+        format_string += "\nRandom select {0} from: (".format(self.concurrent)
+        for t in self.rand_transforms:
+            format_string += "\n"
+            format_string += "    {0}".format(t)
+        format_string += ")\nThen, apply:"
+        for t in self.transforms:
+            format_string += "\n"
+            format_string += "    {0}".format(t)
+        format_string += "\n)"
+        return format_string
+
+
+class Compose(object):
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target=None):
+        for t in self.transforms:
+            image, target = t(image, target)
+        if target is None:
+            return image
+        return image, target
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + "("
+        for t in self.transforms:
+            format_string += "\n"
+            format_string += "    {0}".format(t)
+        format_string += "\n)"
+        return format_string
+
+
+class Resize(object):
+    def __init__(self, min_size, max_size, restrict=False):
+        if not isinstance(min_size, (list, tuple)):
+            min_size = (min_size,)
+        self.min_size = min_size
+        self.max_size = max_size
+        self.restrict = restrict
+
+    # modified from torchvision to add support for max size
+    def get_size(self, image_size):
+        w, h = image_size
+        size = random.choice(self.min_size)
+        max_size = self.max_size
+        if self.restrict:
+            return (size, max_size)
+        if max_size is not None:
+            min_original_size = float(min((w, h)))
+            max_original_size = float(max((w, h)))
+            if max_original_size / min_original_size * size > max_size:
+                size = int(round(max_size * min_original_size / max_original_size))
+
+        if (w <= h and w == size) or (h <= w and h == size):
+            return (h, w)
+
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+        else:
+            oh = size
+            ow = int(size * w / h)
+
+        return (oh, ow)
+
+    def __call__(self, image, target):
+        if isinstance(image, np.ndarray):
+            image_size = self.get_size(image.shape[:2])
+            image = cv2.resize(image, image_size)
+            new_size = image_size
+        else:
+            image = F.resize(image, self.get_size(image.size))
+            new_size = image.size
+        if target is not None:
+            target = target.resize(new_size)
+        return image, target
+
+
+class RandomHorizontalFlip(object):
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, target):
+        if random.random() < self.prob:
+            if isinstance(image, np.ndarray):
+                image = np.fliplr(image)
+            else:
+                image = F.hflip(image)
+            if target is not None:
+                target = target.transpose(0)
+        return image, target
+
+
+class RandomVerticalFlip(object):
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, target):
+        if random.random() < self.prob:
+            if isinstance(image, np.ndarray):
+                image = np.flipud(image)
+            else:
+                image = F.vflip(image)
+            target = target.transpose(1)
+        return image, target
+
+class ToTensor(object):
+    def __call__(self, image, target):
+        return F.to_tensor(image), target
+
+
+class Normalize(object):
+    def __init__(self, mean, std, format='rgb'):
+        self.mean = mean
+        self.std = std
+        self.format = format.lower()
+
+    def __call__(self, image, target):
+        if 'bgr' in self.format:
+            image = image[[2, 1, 0]]
+        if '255' in self.format:
+            image = image * 255
+        image = F.normalize(image, mean=self.mean, std=self.std)
+        return image, target
+
+
+class ColorJitter(object):
+    def __init__(self,
+                 brightness=0.0,
+                 contrast=0.0,
+                 saturation=0.0,
+                 hue=0.0,
+                 ):
+        self.color_jitter = torchvision.transforms.ColorJitter(
+            brightness=brightness,
+            contrast=contrast,
+            saturation=saturation,
+            hue=hue,)
+
+    def __call__(self, image, target):
+        image = self.color_jitter(image)
+        return image, target
+
+
+class RandomCrop(object):
+    def __init__(self, prob=0.5, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3):
+        # 1: return ori img
+        self.prob = prob
+        self.sample_mode = (1, *min_ious, 0)
+        self.min_crop_size = min_crop_size
+
+    def __call__(self, img, target):
+        if random.random() > self.prob:
+            return img, target
+
+        h, w, c = img.shape
+        boxes = target.bbox.numpy()
+        labels = target.get_field('labels')
+
+        while True:
+            mode = random.choice(self.sample_mode)
+            if mode == 1:
+                return img, target
+
+            min_iou = mode
+
+            new_w = random.uniform(self.min_crop_size * w, w)
+            new_h = random.uniform(self.min_crop_size * h, h)
+
+            # h / w in [0.5, 2]
+            if new_h / new_w < 0.5 or new_h / new_w > 2:
+                continue
+
+            left = random.uniform(0, w - new_w)
+            top = random.uniform(0, h - new_h)
+
+            patch = np.array([left, top, left + new_w, top + new_h])
+            overlaps = matrix_iou(patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
+            if overlaps.min() < min_iou:
+                continue
+
+            # center of boxes should inside the crop img
+            center = (boxes[:, :2] + boxes[:, 2:]) / 2
+            mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * ( center[:, 1] < patch[3])
+            if not mask.any():
+                continue
+
+            boxes = boxes[mask]
+            labels = labels[mask]
+
+            # adjust boxes
+            img = img[int(patch[1]):int(patch[3]), int(patch[0]):int(patch[2])]
+
+            boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
+            boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
+            boxes -= np.tile(patch[:2], 2)
+
+            new_target = BoxList(boxes, (img.shape[1], img.shape[0]), mode='xyxy')
+            new_target.add_field('labels', labels)
+            return img, new_target
+
+
+class RandomAffine(object):
+    def __init__(self, prob=0.5, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
+                 borderValue=(127.5, 127.5, 127.5)):
+        self.prob = prob
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = scale
+        self.shear = shear
+        self.borderValue = borderValue
+
+    def __call__(self, img, targets=None):
+        if random.random() > self.prob:
+            return img, targets
+        # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
+        # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
+
+        border = 0  # width of added border (optional)
+        #height = max(img.shape[0], img.shape[1]) + border * 2
+        height, width, _ = img.shape
+        bbox = targets.bbox
+
+        # Rotation and Scale
+        R = np.eye(3)
+        a = random.random() * (self.degrees[1] - self.degrees[0]) + self.degrees[0]
+        # a += random.choice([-180, -90, 0, 90])  # 90deg rotations added to small rotations
+        s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
+        R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)
+
+        # Translation
+        T = np.eye(3)
+        T[0, 2] = (random.random() * 2 - 1) * self.translate[0] * img.shape[0] + border  # x translation (pixels)
+        T[1, 2] = (random.random() * 2 - 1) * self.translate[1] * img.shape[1] + border  # y translation (pixels)
+
+        # Shear
+        S = np.eye(3)
+        S[0, 1] = math.tan((random.random() * (self.shear[1] - self.shear[0]) + self.shear[0]) * math.pi / 180)  # x shear (deg)
+        S[1, 0] = math.tan((random.random() * (self.shear[1] - self.shear[0]) + self.shear[0]) * math.pi / 180)  # y shear (deg)
+
+        M = S @ T @ R  # Combined rotation matrix. ORDER IS IMPORTANT HERE!!
+        imw = cv2.warpPerspective(img, M, dsize=(width, height), flags=cv2.INTER_LINEAR,
+                                  borderValue=self.borderValue)  # BGR order borderValue
+
+        # Return warped points also
+        if targets:
+            n = bbox.shape[0]
+            points = bbox[:, 0:4]
+            area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])
+
+            # warp points
+            xy = np.ones((n * 4, 3))
+            xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
+            xy = (xy @ M.T)[:, :2].reshape(n, 8)
+
+            # create new boxes
+            x = xy[:, [0, 2, 4, 6]]
+            y = xy[:, [1, 3, 5, 7]]
+            xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+
+            # apply angle-based reduction
+            radians = a * math.pi / 180
+            reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
+            x = (xy[:, 2] + xy[:, 0]) / 2
+            y = (xy[:, 3] + xy[:, 1]) / 2
+            w = (xy[:, 2] - xy[:, 0]) * reduction
+            h = (xy[:, 3] - xy[:, 1]) * reduction
+            xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
+
+            # reject warped points outside of image
+            x1 = np.clip(xy[:,0], 0, width)
+            y1 = np.clip(xy[:,1], 0, height)
+            x2 = np.clip(xy[:,2], 0, width)
+            y2 = np.clip(xy[:,3], 0, height)
+            new_bbox = np.concatenate((x1, y1, x2, y2)).reshape(4, n).T
+            targets.bbox = torch.as_tensor(new_bbox, dtype=torch.float32)
+
+        return imw, targets
+
+
+class RandomErasing:
+    def __init__(self, prob=0.5, era_l=0.02, era_h=1/3, min_aspect=0.3,
+                 mode='const', max_count=1, max_overlap=0.3, max_value=255):
+        self.prob = prob
+        self.era_l = era_l
+        self.era_h = era_h
+        self.min_aspect = min_aspect
+        self.min_count = 1
+        self.max_count = max_count
+        self.max_overlap = max_overlap
+        self.max_value = max_value
+        self.mode = mode.lower()
+        assert self.mode in ['const', 'rand', 'pixel'], 'invalid erase mode: %s' % self.mode
+
+    def _get_pixels(self, patch_size):
+        if self.mode == 'pixel':
+            return np.random.random(patch_size)*self.max_value
+        elif self.mode == 'rand':
+            return np.random.random((1, 1, patch_size[-1]))*self.max_value
+        else:
+            return np.zeros((1, 1, patch_size[-1]))
+
+    def __call__(self, image, target):
+        if random.random() > self.prob:
+            return image, target
+        ih, iw, ic = image.shape
+        ia = ih * iw
+        count = self.min_count if self.min_count == self.max_count else \
+            random.randint(self.min_count, self.max_count)
+        erase_boxes = []
+        for _ in range(count):
+            for try_idx in range(10):
+                erase_area = random.uniform(self.era_l, self.era_h) * ia / count
+                aspect_ratio = math.exp(random.uniform(math.log(self.min_aspect), math.log(1/self.min_aspect)))
+                eh = int(round(math.sqrt(erase_area * aspect_ratio)))
+                ew = int(round(math.sqrt(erase_area / aspect_ratio)))
+                if eh < ih and ew < iw:
+                    x = random.randint(0, iw - ew)
+                    y = random.randint(0, ih - eh)
+                    image[y:y+eh, x:x+ew, :] = self._get_pixels((eh, ew, ic))
+                    erase_boxes.append([x,y,x+ew,y+eh])
+                break
+
+        if target is not None and len(erase_boxes)>0:
+            boxes = target.bbox.numpy()
+            labels = target.get_field('labels')
+            overlap = matrix_iou(np.array(erase_boxes), boxes, relative=True)
+            mask = overlap.max(axis=0)<self.max_overlap
+            boxes = boxes[mask]
+            labels = labels[mask]
+            target.bbox = torch.as_tensor(boxes, dtype=torch.float32)
+            target.add_field('labels', labels)
+
+        return image, target
diff --git a/maskrcnn_benchmark/engine/__init__.py b/maskrcnn_benchmark/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bc96c7a6bf8379e1adfb3e4adf536107b385fa9
--- /dev/null
+++ b/maskrcnn_benchmark/engine/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/maskrcnn_benchmark/engine/alter_trainer.py b/maskrcnn_benchmark/engine/alter_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc57a62f0e25b1d26af7e7c1d4b0f2cb5ee6b5c
--- /dev/null
+++ b/maskrcnn_benchmark/engine/alter_trainer.py
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import datetime
+import logging
+import time
+
+import torch
+import torch.distributed as dist
+
+from maskrcnn_benchmark.utils.comm import get_world_size
+from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+
+
+def reduce_loss_dict(all_loss_dict):
+    """
+    Reduce the loss dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    loss_dict, after reduction.
+    """
+    world_size = get_world_size()
+    with torch.no_grad():
+        loss_names = []
+        all_losses = []
+        for loss_dict in all_loss_dict:
+            for k in sorted(loss_dict.keys()):
+                loss_names.append(k)
+                all_losses.append(loss_dict[k])
+        all_losses = torch.stack(all_losses, dim=0)
+        if world_size > 1:
+            dist.reduce(all_losses, dst=0)
+            if dist.get_rank() == 0:
+                # only main process gets accumulated, so only divide by
+                # world_size in this case
+                all_losses /= world_size
+
+        reduced_losses = {}
+        for k, v in zip(loss_names, all_losses):
+            if k not in reduced_losses:
+                reduced_losses[k] = v / len(all_loss_dict)
+            reduced_losses[k] += v / len(all_loss_dict)
+
+    return reduced_losses
+
+
+def do_train(
+        model,
+        data_loader,
+        optimizer,
+        scheduler,
+        checkpointer,
+        device,
+        checkpoint_period,
+        arguments,
+):
+    logger = logging.getLogger("maskrcnn_benchmark.trainer")
+    logger.info("Start training")
+    meters = MetricLogger(delimiter="  ")
+    max_iter = min(len(task_loader) for task_loader in data_loader)
+    start_iter = arguments["iteration"]
+    model.train()
+    start_training_time = time.time()
+    end = time.time()
+    for iteration, task_loader in enumerate(zip(*data_loader), start_iter):
+        data_time = time.time() - end
+        iteration = iteration + 1
+        arguments["iteration"] = iteration
+
+        all_task_loss_dict = []
+        for task, (images, targets, _) in enumerate(task_loader, 1):
+            if all(len(target) < 1 for target in targets):
+                logger.warning('Sampled all negative batches, skip')
+                continue
+
+            images = images.to(device)
+            targets = [target.to(device) for target in targets]
+
+            loss_dict = model(images, targets, task)
+            all_task_loss_dict.append(loss_dict)
+
+        losses = sum(loss for loss_dict in all_task_loss_dict for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = reduce_loss_dict(all_task_loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+        meters.update(loss=losses_reduced, **loss_dict_reduced)
+
+        optimizer.zero_grad()
+        losses.backward()
+        optimizer.step()
+        scheduler.step()
+
+        batch_time = time.time() - end
+        end = time.time()
+        meters.update(time=batch_time, data=data_time)
+
+        eta_seconds = meters.time.global_avg * (max_iter - iteration)
+        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+        if iteration % 20 == 0 or iteration == max_iter:
+            logger.info(
+                meters.delimiter.join(
+                    [
+                        "eta: {eta}",
+                        "iter: {iter}",
+                        "{meters}",
+                        "lr: {lr:.6f}",
+                        "max mem: {memory:.0f}",
+                    ]
+                ).format(
+                    eta=eta_string,
+                    iter=iteration,
+                    meters=str(meters),
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+                )
+            )
+        if iteration % checkpoint_period == 0:
+            checkpointer.save("model_{:07d}".format(iteration), **arguments)
+        if iteration == max_iter:
+            checkpointer.save("model_final", **arguments)
+
+    total_training_time = time.time() - start_training_time
+    total_time_str = str(datetime.timedelta(seconds=total_training_time))
+    logger.info(
+        "Total training time: {} ({:.4f} s / it)".format(
+            total_time_str, total_training_time / (max_iter)
+        )
+    )
diff --git a/maskrcnn_benchmark/engine/evolution.py b/maskrcnn_benchmark/engine/evolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..40b41c3e6550de19e6d06d722121ff59e205f6ce
--- /dev/null
+++ b/maskrcnn_benchmark/engine/evolution.py
@@ -0,0 +1,357 @@
+
+import time
+import pickle
+import logging
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+from collections import OrderedDict
+from yaml import safe_dump
+from yacs.config import load_cfg, CfgNode#, _to_dict
+from maskrcnn_benchmark.config import cfg
+from maskrcnn_benchmark.engine.inference import _accumulate_predictions_from_multiple_gpus
+from maskrcnn_benchmark.modeling.backbone.nas import get_layer_name
+from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, get_world_size, all_gather
+from maskrcnn_benchmark.data.datasets.evaluation import evaluate
+from maskrcnn_benchmark.utils.flops import profile
+
+
+choice = lambda x:x[np.random.randint(len(x))] if isinstance(x,tuple) else choice(tuple(x))
+
+
+def gather_candidates(all_candidates):
+    all_candidates = all_gather(all_candidates)
+    all_candidates = [cand for candidates in all_candidates for cand in candidates]
+    return list(set(all_candidates))
+
+
+def gather_stats(all_candidates):
+    all_candidates = all_gather(all_candidates)
+    reduced_statcs = {}
+    for candidates in all_candidates:
+        reduced_statcs.update(candidates) # will replace the existing key with last value if more than one exists
+    return reduced_statcs
+
+
+def compute_on_dataset(model, rngs, data_loader, device=cfg.MODEL.DEVICE):
+    model.eval()
+    results_dict = {}
+    cpu_device = torch.device("cpu")
+    for _, batch in enumerate(data_loader):
+        images, targets, image_ids = batch
+        with torch.no_grad():
+            output = model(images.to(device), rngs=rngs)
+            output = [o.to(cpu_device) for o in output]
+        results_dict.update(
+            {img_id: result for img_id, result in zip(image_ids, output)}
+        )
+    return results_dict
+
+
+def bn_statistic(model, rngs, data_loader, device=cfg.MODEL.DEVICE, max_iter=500):
+    for name, param in model.named_buffers():
+        if 'running_mean' in name:
+            nn.init.constant_(param, 0)
+        if 'running_var' in name:
+            nn.init.constant_(param, 1)
+
+    model.train()
+    for iteration, (images, targets, _) in enumerate(data_loader, 1):
+        images = images.to(device)
+        targets = [target.to(device) for target in targets]
+        with torch.no_grad():
+            loss_dict = model(images, targets, rngs)
+        if iteration >= max_iter:
+            break
+
+    return model
+
+
+def inference(
+        model,
+        rngs,
+        data_loader,
+        iou_types=("bbox",),
+        box_only=False,
+        device="cuda",
+        expected_results=(),
+        expected_results_sigma_tol=4,
+        output_folder=None,
+):
+
+    # convert to a torch.device for efficiency
+    device = torch.device(device)
+    dataset = data_loader.dataset
+    predictions = compute_on_dataset(model, rngs, data_loader, device)
+    # wait for all processes to complete before measuring the time
+    synchronize()
+
+    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
+    if not is_main_process():
+        return
+
+    extra_args = dict(
+        box_only=box_only,
+        iou_types=iou_types,
+        expected_results=expected_results,
+        expected_results_sigma_tol=expected_results_sigma_tol,
+    )
+
+    return evaluate(dataset=dataset,
+                    predictions=predictions,
+                    output_folder=output_folder,
+                    **extra_args)
+
+
+def fitness(cfg, model, rngs, val_loaders):
+    iou_types = ("bbox",)
+    if cfg.MODEL.MASK_ON:
+        iou_types = iou_types + ("segm",)
+    for data_loader_val in val_loaders:
+        results = inference(
+            model,
+            rngs,
+            data_loader_val,
+            iou_types=iou_types,
+            box_only=False,
+            device=cfg.MODEL.DEVICE,
+            expected_results=cfg.TEST.EXPECTED_RESULTS,
+            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
+        )
+        synchronize()
+
+    return results
+
+
+class EvolutionTrainer(object):
+    def __init__(self, cfg, model, flops_limit=None, is_distributed=True):
+
+        self.log_dir = cfg.OUTPUT_DIR
+        self.checkpoint_name = os.path.join(self.log_dir,'evolution.pth')
+        self.is_distributed = is_distributed
+
+        self.states = model.module.mix_nums if is_distributed else model.mix_nums
+        self.supernet_state_dict = pickle.loads(pickle.dumps(model.state_dict()))
+        self.flops_limit = flops_limit
+        self.model = model
+
+        self.candidates = []
+        self.vis_dict = {}
+
+        self.max_epochs = cfg.SEARCH.MAX_EPOCH
+        self.select_num = cfg.SEARCH.SELECT_NUM
+        self.population_num = cfg.SEARCH.POPULATION_NUM/get_world_size()
+        self.mutation_num = cfg.SEARCH.MUTATION_NUM/get_world_size()
+        self.crossover_num = cfg.SEARCH.CROSSOVER_NUM/get_world_size()
+        self.mutation_prob = cfg.SEARCH.MUTATION_PROB/get_world_size()
+
+        self.keep_top_k = {self.select_num:[], 50:[]}
+        self.epoch=0
+        self.cfg = cfg
+
+    def save_checkpoint(self):
+        if not is_main_process():
+            return
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        info = {}
+        info['candidates'] = self.candidates
+        info['vis_dict'] = self.vis_dict
+        info['keep_top_k'] = self.keep_top_k
+        info['epoch'] = self.epoch
+        torch.save(info, self.checkpoint_name)
+        print('Save checkpoint to', self.checkpoint_name)
+
+    def load_checkpoint(self):
+        if not os.path.exists(self.checkpoint_name):
+            return False
+        info = torch.load(self.checkpoint_name)
+        self.candidates = info['candidates']
+        self.vis_dict = info['vis_dict']
+        self.keep_top_k = info['keep_top_k']
+        self.epoch = info['epoch']
+        print('Load checkpoint from', self.checkpoint_name)
+        return True
+
+    def legal(self, cand):
+        assert isinstance(cand,tuple) and len(cand)==len(self.states)
+        if cand in self.vis_dict:
+            return False
+
+        if self.flops_limit is not None:
+            net = self.model.module.backbone if self.is_distributed else self.model.backbone
+            inp = (1, 3, 224, 224)
+            flops, params = profile(net, inp, extra_args={'paths': list(cand)})
+            flops = flops/1e6
+            print('flops:',flops)
+            if flops>self.flops_limit:
+                return False
+
+        return True
+
+    def update_top_k(self, candidates, *, k, key, reverse=False):
+        assert k in self.keep_top_k
+        # print('select ......')
+        t = self.keep_top_k[k]
+        t += candidates
+        t.sort(key=key,reverse=reverse)
+        self.keep_top_k[k]=t[:k]
+
+    def eval_candidates(self, train_loader, val_loader):
+        for cand in self.candidates:
+            t0 = time.time()
+
+            # load back supernet state dict
+            self.model.load_state_dict(self.supernet_state_dict)
+            # bn_statistic
+            model = bn_statistic(self.model, list(cand), train_loader)
+            # fitness
+            evals = fitness(cfg, model, list(cand), val_loader)
+
+            if is_main_process():
+                acc = evals[0].results['bbox']['AP']
+                self.vis_dict[cand] = acc
+                print('candiate ', cand)
+                print('time: {}s'.format(time.time() - t0))
+                print('acc ', acc)
+
+    def stack_random_cand(self, random_func, *, batchsize=10):
+        while True:
+            cands = [random_func() for _ in range(batchsize)]
+            for cand in cands:
+                yield cand
+
+    def random_can(self, num):
+        # print('random select ........')
+        candidates = []
+        cand_iter = self.stack_random_cand(lambda:tuple(np.random.randint(i) for i in self.states))
+        while len(candidates)<num:
+            cand = next(cand_iter)
+
+            if not self.legal(cand):
+                continue
+            candidates.append(cand)
+            #print('random {}/{}'.format(len(candidates),num))
+
+        # print('random_num = {}'.format(len(candidates)))
+        return candidates
+
+    def get_mutation(self, k, mutation_num, m_prob):
+        assert k in self.keep_top_k
+        # print('mutation ......')
+        res = []
+        iter = 0
+        max_iters = mutation_num*10
+
+        def random_func():
+            cand = list(choice(self.keep_top_k[k]))
+            for i in range(len(self.states)):
+                if np.random.random_sample()<m_prob:
+                    cand[i] = np.random.randint(self.states[i])
+            return tuple(cand)
+
+        cand_iter = self.stack_random_cand(random_func)
+        while len(res)<mutation_num and max_iters>0:
+            cand = next(cand_iter)
+            if not self.legal(cand):
+                continue
+            res.append(cand)
+            #print('mutation {}/{}'.format(len(res),mutation_num))
+            max_iters-=1
+
+        # print('mutation_num = {}'.format(len(res)))
+        return res
+
+    def get_crossover(self, k, crossover_num):
+        assert k in self.keep_top_k
+        # print('crossover ......')
+        res = []
+        iter = 0
+        max_iters = 10 * crossover_num
+
+        def random_func():
+            p1=choice(self.keep_top_k[k])
+            p2=choice(self.keep_top_k[k])
+            return tuple(choice([i,j]) for i,j in zip(p1,p2))
+
+        cand_iter = self.stack_random_cand(random_func)
+        while len(res)<crossover_num and max_iters>0:
+            cand = next(cand_iter)
+            if not self.legal(cand):
+                continue
+            res.append(cand)
+            #print('crossover {}/{}'.format(len(res),crossover_num))
+            max_iters-=1
+
+        # print('crossover_num = {}'.format(len(res)))
+        return res
+
+    def train(self, train_loader, val_loader):
+        logger = logging.getLogger("maskrcnn_benchmark.evolution")
+
+        if not self.load_checkpoint():
+            self.candidates = gather_candidates(self.random_can(self.population_num))
+
+        while self.epoch<self.max_epochs:
+            self.eval_candidates(train_loader, val_loader)
+            self.vis_dict = gather_stats(self.vis_dict)
+
+            self.update_top_k(self.candidates, k=self.select_num, key=lambda x:1-self.vis_dict[x])
+            self.update_top_k(self.candidates, k=50, key=lambda x:1-self.vis_dict[x])
+
+            if is_main_process():
+                logger.info('Epoch {} : top {} result'.format(self.epoch+1, len(self.keep_top_k[self.select_num])))
+                for i,cand in enumerate(self.keep_top_k[self.select_num]):
+                    logger.info('     No.{} {} perf = {}'.format(i+1, cand, self.vis_dict[cand]))
+
+            mutation = gather_candidates(self.get_mutation(self.select_num, self.mutation_num, self.mutation_prob))
+            crossover = gather_candidates(self.get_crossover(self.select_num, self.crossover_num))
+            rand = gather_candidates(self.random_can(self.population_num - len(mutation) - len(crossover)))
+
+            self.candidates = mutation + crossover + rand
+
+            self.epoch+=1
+            self.save_checkpoint()
+
+    def save_candidates(self, cand, template):
+        paths = self.keep_top_k[self.select_num][cand-1]
+
+        with open(template, "r") as f:
+            super_cfg = load_cfg(f)
+
+        search_spaces = {}
+        for mix_ops in super_cfg.MODEL.BACKBONE.LAYER_SEARCH:
+            search_spaces[mix_ops] = super_cfg.MODEL.BACKBONE.LAYER_SEARCH[mix_ops]
+        search_layers = super_cfg.MODEL.BACKBONE.LAYER_SETUP
+
+        layer_setup = []
+        for i, layer in enumerate(search_layers):
+            name, setup = get_layer_name(layer, search_spaces)
+            if not isinstance(name, list):
+                name = [name]
+            name = name[paths[i]]
+
+            layer_setup.append("('{}', {})".format(name, str(setup)[1:-1]))
+        super_cfg.MODEL.BACKBONE.LAYER_SETUP = layer_setup
+
+        cand_cfg = _to_dict(super_cfg)
+        del cand_cfg['MODEL']['BACKBONE']['LAYER_SEARCH']
+        with open(os.path.join(self.cfg.OUTPUT_DIR, os.path.basename(template)).replace('.yaml','_cand{}.yaml'.format(cand)), 'w') as f:
+            f.writelines(safe_dump(cand_cfg))
+
+        super_weight = self.supernet_state_dict
+        cand_weight = OrderedDict()
+        cand_keys = ['layers.{}.ops.{}'.format(i, c) for i, c in enumerate(paths)]
+
+        for key, val in super_weight.items():
+            if 'ops' in key:
+                for ck in cand_keys:
+                    if ck in key:
+                        cand_weight[key.replace(ck,ck.split('.ops.')[0])] = val
+            else:
+                cand_weight[key] = val
+
+        torch.save({'model':cand_weight}, os.path.join(self.cfg.OUTPUT_DIR, 'init_cand{}.pth'.format(cand)))
diff --git a/maskrcnn_benchmark/engine/inference.py b/maskrcnn_benchmark/engine/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ad3536ddc77558a249c47fb9d6e6671b07ed5a
--- /dev/null
+++ b/maskrcnn_benchmark/engine/inference.py
@@ -0,0 +1,623 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import datetime
+import logging
+import time
+import os
+import re
+
+import torch
+from tqdm import tqdm
+from collections import defaultdict
+
+from maskrcnn_benchmark.data.datasets.evaluation import evaluate, im_detect_bbox_aug
+from ..utils.comm import is_main_process
+from ..utils.comm import all_gather
+from ..utils.comm import synchronize
+import pdb
+from maskrcnn_benchmark.data.datasets.evaluation.flickr.flickr_eval import FlickrEvaluator
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+import matplotlib.pyplot as plt
+import matplotlib.pylab as pylab
+from maskrcnn_benchmark.data.datasets.tsv import load_from_yaml_file
+def imshow(img, file_name = "tmp.jpg"):
+    plt.imshow(img[:, :, [2, 1, 0]])
+    plt.axis("off")
+    #plt.figtext(0.5, 0.09, "test", wrap=True, horizontalalignment='center', fontsize=20)
+    plt.savefig(file_name)
+def load(url_or_file_name):
+    try:
+        response = requests.get(url_or_file_name)
+    except:
+        response = None
+    if response is None:
+        pil_image = Image.open(url_or_file_name).convert("RGB")
+    else:
+        pil_image = Image.open(BytesIO(response.content)).convert("RGB")
+    # convert to BGR format
+    image = np.array(pil_image)[:, :, [2, 1, 0]]
+    return image
+def inference_default(
+        model,
+        data_loader,
+        dataset_name,
+        iou_types=("bbox",),
+        box_only=False,
+        device="cuda",
+        expected_results=(),
+        expected_results_sigma_tol=4,
+        output_folder=None,
+        cfg=None
+):
+    # convert to a torch.device for efficiency
+    device = torch.device(device)
+    num_devices = (
+        torch.distributed.get_world_size()
+        if torch.distributed.is_initialized()
+        else 1
+    )
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+    dataset = data_loader.dataset
+    logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
+    start_time = time.time()
+
+    model.eval()
+    results_dict = {}
+    cpu_device = torch.device("cpu")
+    for i, batch in enumerate(tqdm(data_loader)):
+        images, targets, image_ids, *_ = batch
+        with torch.no_grad():
+            if cfg.TEST.USE_MULTISCALE:
+                output = im_detect_bbox_aug(model, images, device)
+            else:
+                output = model(images.to(device))
+            output = [o.to(cpu_device) for o in output]
+        results_dict.update(
+            {img_id: result for img_id, result in zip(image_ids, output)}
+        )
+    predictions = results_dict
+    # wait for all processes to complete before measuring the time
+    synchronize()
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=total_time))
+    logger.info(
+        "Total inference time: {} ({} s / img per device, on {} devices)".format(
+            total_time_str, total_time * num_devices / len(dataset), num_devices
+        )
+    )
+
+    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
+    if not is_main_process():
+        return None
+
+    if output_folder:
+        torch.save(predictions, os.path.join(output_folder, "predictions.pth"))
+
+    extra_args = dict(
+        box_only=box_only,
+        iou_types=iou_types,
+        expected_results=expected_results,
+        expected_results_sigma_tol=expected_results_sigma_tol,
+    )
+    return evaluate(dataset=dataset, predictions=predictions, output_folder=output_folder, **extra_args)
+
+
+def clean_name(name):
+    name = re.sub(r"\(.*\)", "", name)
+    name = re.sub(r"_", " ", name)
+    name = re.sub(r"  ", " ", name)
+    return name
+
+
+def create_one_hot_dict(labels, no_minus_one_for_one_hot = False):
+    positive_map_token_to_label = defaultdict(int)
+    positive_map_label_to_token = defaultdict(int)
+
+    for i in range(len(labels)):
+        positive_map_token_to_label[i] = labels[i]
+        positive_map_label_to_token[labels[i]] = i
+
+    if no_minus_one_for_one_hot:
+        positive_map_token_to_label = defaultdict(int)
+        positive_map_label_to_token = defaultdict(int)
+
+        for i in range(len(labels)):
+            positive_map_token_to_label[i+1] = labels[i]
+            positive_map_label_to_token[labels[i]] = i + 1
+
+    return positive_map_token_to_label, positive_map_label_to_token
+
+
+def create_positive_dict(tokenized, tokens_positive, labels):
+    """construct a dictionary such that positive_map[i] = j, iff token i is mapped to j label"""
+    positive_map = defaultdict(int)
+
+    # Additionally, have positive_map_label_to_tokens
+    positive_map_label_to_token = defaultdict(list)
+
+    for j, tok_list in enumerate(tokens_positive):
+        for (beg, end) in tok_list:
+            beg_pos = tokenized.char_to_token(beg)
+            end_pos = tokenized.char_to_token(end - 1)
+            if beg_pos is None:
+                try:
+                    beg_pos = tokenized.char_to_token(beg + 1)
+                    if beg_pos is None:
+                        beg_pos = tokenized.char_to_token(beg + 2)
+                except:
+                    beg_pos = None
+            if end_pos is None:
+                try:
+                    end_pos = tokenized.char_to_token(end - 2)
+                    if end_pos is None:
+                        end_pos = tokenized.char_to_token(end - 3)
+                except:
+                    end_pos = None
+            if beg_pos is None or end_pos is None:
+                continue
+
+            assert beg_pos is not None and end_pos is not None
+            for i in range(beg_pos, end_pos + 1):
+                positive_map[i] = labels[j]  # because the labels starts from 1
+                positive_map_label_to_token[labels[j]].append(i)
+            # positive_map[j, beg_pos : end_pos + 1].fill_(1)
+    return positive_map, positive_map_label_to_token  # / (positive_map.sum(-1)[:, None] + 1e-6)
+
+def chunks(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    all_ = []
+    for i in range(0, len(lst), n):
+        data_index = lst[i:i + n]
+        all_.append(data_index)
+    counter = 0
+    for i in all_:
+        counter += len(i)
+    assert(counter == len(lst))
+
+    return all_
+
+def create_queries_and_maps_from_dataset(dataset, cfg):
+    categories = dataset.categories()
+    #one_hot = dataset.one_hot
+
+    labels = []
+    label_list = []
+    keys = list(categories.keys())
+    keys.sort()
+    for i in keys:
+        labels.append(i)
+        label_list.append(categories[i])
+
+    if cfg.TEST.CHUNKED_EVALUATION != -1:
+        labels = chunks(labels, cfg.TEST.CHUNKED_EVALUATION)
+        label_list = chunks(label_list, cfg.TEST.CHUNKED_EVALUATION)
+    else:
+        labels = [labels]
+        label_list = [label_list]
+
+    all_queries = []
+    all_positive_map_label_to_token = []
+
+    for i in range(len(labels)):
+        labels_i = labels[i]
+        label_list_i = label_list[i]
+        query_i, positive_map_label_to_token_i = create_queries_and_maps(
+            labels_i, label_list_i, additional_labels = cfg.DATASETS.SUPRESS_QUERY if cfg.DATASETS.USE_SUPRESS_QUERY else None, cfg = cfg)
+        
+        all_queries.append(query_i)
+        all_positive_map_label_to_token.append(positive_map_label_to_token_i)
+    print("All queries", all_queries)
+    return all_queries, all_positive_map_label_to_token
+
+def create_queries_and_maps(labels, label_list, additional_labels = None, cfg = None):
+
+    # Clean label list
+    original_label_list = label_list.copy()
+    label_list = [clean_name(i) for i in label_list]
+    # Form the query and get the mapping
+    tokens_positive = []
+    start_i = 0
+    end_i = 0
+    objects_query = ""
+
+    # sep between tokens, follow training
+    separation_tokens = cfg.DATASETS.SEPARATION_TOKENS
+    
+    caption_prompt = cfg.DATASETS.CAPTION_PROMPT
+    if caption_prompt is not None and isinstance(caption_prompt, str):
+        caption_prompt = load_from_yaml_file(caption_prompt)
+    use_caption_prompt = cfg.DATASETS.USE_CAPTION_PROMPT and caption_prompt is not None
+    for _index, label in enumerate(label_list):
+        if use_caption_prompt:
+            objects_query += caption_prompt[_index]["prefix"]
+        
+        start_i = len(objects_query)
+
+        if use_caption_prompt:
+            objects_query += caption_prompt[_index]["name"]
+        else:
+            objects_query += label
+        
+        end_i = len(objects_query)
+        tokens_positive.append([(start_i, end_i)])  # Every label has a [(start, end)]
+        
+        if use_caption_prompt:
+            objects_query += caption_prompt[_index]["suffix"]
+
+        if _index != len(label_list) - 1:
+            objects_query += separation_tokens
+    
+    if additional_labels is not None:
+        objects_query += separation_tokens
+        for _index, label in enumerate(additional_labels):
+            objects_query += label
+            if _index != len(additional_labels) - 1:
+                objects_query += separation_tokens
+
+    print(objects_query)
+
+    from transformers import AutoTokenizer
+    # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        tokenized = tokenizer(objects_query, return_tensors="pt")
+    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
+        from transformers import CLIPTokenizerFast
+        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
+            tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                                        from_slow=True, mask_token='ðŁĴij</w>')
+        else:
+            tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                                        from_slow=True)
+        tokenized = tokenizer(objects_query,
+                              max_length=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
+                              truncation=True,
+                              return_tensors="pt")
+    else:
+        tokenizer = None
+        raise NotImplementedError
+
+    # Create the mapping between tokenized sentence and the original label
+    positive_map_token_to_label, positive_map_label_to_token = create_positive_dict(tokenized, tokens_positive,
+                                                                                        labels=labels)  # from token position to original label
+    return objects_query, positive_map_label_to_token
+
+def create_positive_map_label_to_token_from_positive_map(positive_map, plus = 0):
+    positive_map_label_to_token = {}
+    for i in range(len(positive_map)):
+        positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist()
+    return positive_map_label_to_token
+
+
+
+def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
+    all_predictions = all_gather(predictions_per_gpu)
+    if not is_main_process():
+        return
+    # merge the list of dicts
+    predictions = {}
+    for p in all_predictions:
+        predictions.update(p)
+    # convert a dict where the key is the index in a list
+    image_ids = list(sorted(predictions.keys()))
+    if len(image_ids) != image_ids[-1] + 1:
+        logger = logging.getLogger("maskrcnn_benchmark.inference")
+        logger.warning(
+            "Number of images that were gathered from multiple processes is not "
+            "a contiguous set. Some images might be missing from the evaluation"
+        )
+
+    # convert to a list
+    predictions = [predictions[i] for i in image_ids]
+    return predictions
+
+def resize_box(output, targets):
+    if isinstance(targets[0], dict):
+        orig_target_sizes = targets[0]["orig_size"].unsqueeze(0)
+    else:
+        orig_target_sizes = torch.stack([targets[0].extra_fields["orig_size"] for _ in range(1)], dim=0)
+    img_h, img_w = orig_target_sizes.unbind(1)
+    return output.resize((img_w, img_h))
+
+def flickr_post_process(output, targets, positive_map_label_to_token, plus):
+    output = resize_box(output, targets)
+    scores, indices = torch.topk(output.extra_fields["scores"], k = len(output.extra_fields["scores"]), sorted=True)
+    boxes = output.bbox.tolist()
+    boxes = [boxes[i] for i in indices]
+    labels = [output.extra_fields["labels"][i] for i in indices]
+    output_boxes = [[] for i in range(len(positive_map_label_to_token))]
+    output_scores = [[] for i in range(len(positive_map_label_to_token))]
+    for i in range(len(boxes)):
+        output_boxes[labels[i] - plus].append(boxes[i])
+        output_scores[labels[i] - plus].append(scores[i])
+    for i in output_boxes:
+        i.append([0.0, 0.0, 0.0, 0.0])
+    image_ids = [t.extra_fields["original_img_id"] for t in targets]
+    sentence_ids = [t.extra_fields["sentence_id"] for t in targets]
+
+    return {"image_id": image_ids[0], "sentence_id": sentence_ids[0], "boxes": output_boxes, "scores": output_scores}
+
+def build_flickr_evaluator(cfg):
+    evaluator = FlickrEvaluator(
+        "DATASET/flickr30k/flickr30k/", # Hard written!!
+        subset="test" if "test" in cfg.DATASETS.TEST[0]  else "val",
+        merge_boxes=cfg.DATASETS.FLICKR_GT_TYPE == "merged")
+    return evaluator
+
+def build_lvis_evaluator(ann_file, fixed_ap=True):
+    from maskrcnn_benchmark.data.datasets.evaluation.lvis.lvis import LVIS
+    from maskrcnn_benchmark.data.datasets.evaluation.lvis.lvis_eval import LvisEvaluatorFixedAP, LvisEvaluator
+    evaluator = LvisEvaluatorFixedAP(LVIS(ann_file), fixed_ap=fixed_ap)
+    #evaluator = LvisEvaluator(LVIS(ann_file), iou_types=['segm', 'bbox'])
+    return evaluator
+
+def write_lvis_results(results, output_file_name):
+    lines = []
+    lines.append("metric, avg ")
+    for each_result in results:
+        metric_string = " ".join(each_result.split(" ")[:-2])
+        number = each_result.split(" ")[-1]
+        each_result = metric_string + ", " + number + " "
+        lines.append(each_result)
+
+    string_to_write = "\n".join(lines) + "\n"
+    with open(output_file_name, "w") as f:
+        f.write(string_to_write)
+    return
+
+def write_flickr_results(results, output_file_name):
+    '''
+    {'Recall@1_all': 0.8394651146677753, 'Recall@1_animals': 0.9177820267686424, 'Recall@1_bodyparts': 0.7097966728280961, ...}
+    '''
+    lines = []
+    lines.append("metric, avg ")
+    for each_metric, number in results.items():
+        each_result = each_metric + ", " + str(number) + " "
+        lines.append(each_result)
+
+    string_to_write = "\n".join(lines) + "\n"
+    with open(output_file_name, "w") as f:
+        f.write(string_to_write)
+    return
+
+def inference(
+        model,
+        data_loader,
+        dataset_name,
+        iou_types=("bbox",),
+        box_only=False,
+        device="cuda",
+        expected_results=(),
+        expected_results_sigma_tol=4,
+        output_folder=None,
+        cfg=None,
+        verbose=True,
+        visualizer = None
+):
+    # convert to a torch.device for efficiency
+    try:
+        device = torch.device(device)
+    except:
+        device = device
+    num_devices = (
+        torch.distributed.get_world_size()
+        if torch.distributed.is_initialized()
+        else 1
+    )
+    logger = logging.getLogger("maskrcnn_benchmark.inference")
+    dataset = data_loader.dataset
+    if verbose:
+        logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
+    start_time = time.time()
+
+    task = cfg.TEST.EVAL_TASK
+
+    if not task:
+        return inference_default(model, data_loader, dataset_name, iou_types, box_only, device, expected_results, expected_results_sigma_tol, output_folder, cfg)
+        
+    if cfg.GLIPKNOW.PARALLEL_LANGUAGE_INPUT:
+        assert task == 'detection'
+        categories = dataset.categories()
+
+        keys = list(categories.keys())
+        keys.sort()
+        all_queries = [[categories[k] for k in keys]]
+        all_positive_map_label_to_token = [{k: [i] for i, k in enumerate(keys)}]
+    elif task == "detection":
+        all_queries, all_positive_map_label_to_token = create_queries_and_maps_from_dataset(dataset, cfg)
+    elif task == "grounding":
+        all_queries = [None]
+        all_positive_map_label_to_token = [None]
+    else:
+        assert(0)
+
+    '''
+    Build Dataset Sepecific Evaluator
+    '''
+    if "flickr" in cfg.DATASETS.TEST[0]:
+        evaluator = build_flickr_evaluator(cfg)
+    elif "lvis" in cfg.DATASETS.TEST[0]:
+        evaluator = build_lvis_evaluator(dataset.ann_file, fixed_ap=not cfg.DATASETS.LVIS_USE_NORMAL_AP)
+    else:
+        evaluator = None
+
+    model.eval()
+    results_dict = {}
+    cpu_device = torch.device("cpu")
+    if verbose:
+        _iterator = tqdm(data_loader)
+    else:
+        _iterator = data_loader
+    for i, batch in enumerate(_iterator):
+        if i == cfg.TEST.SUBSET:
+            break
+        images, targets, image_ids, *_ = batch
+
+        all_output = []
+        mdetr_style_output = []
+        with torch.no_grad():
+            if cfg.TEST.USE_MULTISCALE:
+                query_time = len(all_queries)
+                for query_i in range(query_time):
+                    if task == "detection":
+                        captions = [all_queries[query_i] for ii in range(len(targets))]
+                        positive_map_label_to_token = all_positive_map_label_to_token[query_i]
+                    else:
+                        captions = None
+                        positive_map_label_to_token = None
+
+                output = im_detect_bbox_aug(model, images, device, captions, positive_map_label_to_token)
+                output = [o.to(cpu_device) for o in output]
+                all_output.append(output)
+            else:
+                images = images.to(device)
+                query_time = len(all_queries)
+
+                for query_i in range(query_time):
+                    if not isinstance(targets[0], dict): # For LVIS dataset and datasets directly copied from MDETR
+                        targets = [target.to(device) for target in targets]
+                    '''
+                    different datasets seem to have different data format... For LVIS dataset, the target is a dictionary, while for modulatedDataset such as COCO/Flickr, the target is a BoxList
+                    '''
+
+                    if task == "detection":
+                        captions = [all_queries[query_i] for ii in range(len(targets))]
+                        positive_map_label_to_token = all_positive_map_label_to_token[query_i]
+                    elif task == "grounding":
+                        captions = [t.get_field("caption") for t in targets]
+                        positive_map_eval = [t.get_field("positive_map_eval") for t in targets]
+                        if cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
+                            plus = 1
+                        else:
+                            plus = 0
+                        assert(len(positive_map_eval) == 1) # Let's just use one image per batch
+                        positive_map_eval = positive_map_eval[0]
+                        positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map_eval, plus=plus)
+                    output = model(images, captions=captions, positive_map=positive_map_label_to_token)
+                    output = [o.to(cpu_device) for o in output]
+
+                    if "flickr" in cfg.DATASETS.TEST[0]:
+                        output = output[0]
+                        new_output = flickr_post_process(
+                            output,
+                            targets,
+                            positive_map_label_to_token,
+                            plus # This is only used in Flickr
+                        )
+                        mdetr_style_output.append(new_output)
+                    elif "lvis" in cfg.DATASETS.TEST[0]:
+                        output = output[0]
+                        output = resize_box(output, targets)
+                        scores = output.extra_fields["scores"]
+                        labels = output.extra_fields["labels"]
+                        boxes = output.bbox
+                        mdetr_style_output.append((targets[0]["image_id"].item(), {"scores": scores, "labels": labels, "boxes": boxes}))
+                    else:
+                        all_output.append(output)
+        if visualizer is not None:
+            assert(len(all_output) == 1)
+            if "lvis" in cfg.DATASETS.TEST[0]:
+                scores = [o[1]["scores"] for o in mdetr_style_output]
+                labels = [o[1]["labels"] for o in mdetr_style_output]
+                boxes = [o[1]["boxes"] for o in mdetr_style_output]
+                scores = torch.cat(scores, dim=0)
+                labels = torch.cat(labels, dim=0)
+                boxes = torch.cat(boxes, dim=0)
+                visualizer_input = BoxList(boxes, output.size)
+                visualizer_input.add_field("scores", scores)
+                visualizer_input.add_field("labels", labels)
+            else:
+                visualizer_input = all_output[0][0] # single image_visualize
+
+            image_id = dataset.ids[i]
+            try:
+                image_path = os.path.join(dataset.root, dataset.coco.loadImgs(image_id)[0]["file_name"])
+                categories = dataset.coco.dataset["categories"]
+            except:
+                lvis = dataset.lvis
+                img_id = dataset.ids[i]
+                ann_ids = lvis.get_ann_ids(img_ids=img_id)
+                target = lvis.load_anns(ann_ids)
+
+                image_path = "DATASET/coco/" +  "/".join(dataset.lvis.load_imgs(img_id)[0]["coco_url"].split("/")[-2:])
+                categories = dataset.lvis.dataset["categories"]
+
+            image = load(image_path)
+            no_background = True
+            label_list = []
+            for index, i in enumerate(categories):
+                if not no_background or (i["name"] != "__background__" and i['id'] != 0):
+                    label_list.append(i["name"])
+            visualizer.entities =  label_list
+            
+            result, _ = visualizer.visualize_with_predictions(
+                image,
+                visualizer_input, 
+                threshold,
+                alpha=alpha,
+                box_pixel=box_pixel,
+                text_size=text_size,
+                text_pixel=text_pixel,
+                text_offset=text_offset,
+                text_offset_original=text_offset_original,
+                color=color,
+            )
+            imshow(result, "./visualize/img_{}.jpg".format(i))
+        
+        if evaluator is not None:
+            evaluator.update(mdetr_style_output)
+        else:
+            output = [[row[_i] for row in all_output] for _i in range(len(all_output[0]))]
+            for index, i in enumerate(output):
+                output[index] = i[0].concate_box_list(i)
+
+            results_dict.update({img_id: result for img_id, result in zip(image_ids, output)})
+
+    if evaluator is not None:
+        evaluator.synchronize_between_processes()
+        try:
+            evaluator.accumulate()
+        except:
+            print("Evaluator has no accumulation, skipped...")
+        score = evaluator.summarize()
+        print(score)
+        import maskrcnn_benchmark.utils.mdetr_dist as dist
+        if is_main_process():
+            if "flickr" in cfg.DATASETS.TEST[0]:
+                write_flickr_results(score, output_file_name=os.path.join(output_folder, "bbox.csv"))
+            elif "lvis" in cfg.DATASETS.TEST[0]:
+                write_lvis_results(score, output_file_name=os.path.join(output_folder, "bbox.csv"))
+        try:
+            torch.distributed.barrier()
+        except:
+            print("Default process group is not initialized")
+        return
+
+    if evaluator is not None:
+        predictions = mdetr_style_output
+    else:
+        predictions = results_dict
+    # wait for all processes to complete before measuring the time
+    synchronize()
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=total_time))
+    logger.info(
+        "Total inference time: {} ({} s / img per device, on {} devices)".format(
+            total_time_str, total_time * num_devices / len(dataset), num_devices
+        )
+    )
+
+    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
+    print("Accumulated results")
+    if not is_main_process():
+        return None
+
+    if output_folder:
+        torch.save(predictions, os.path.join(output_folder, "predictions.pth"))
+
+    extra_args = dict(
+        box_only=box_only,
+        iou_types=iou_types,
+        expected_results=expected_results,
+        expected_results_sigma_tol=expected_results_sigma_tol,
+    )
+    return evaluate(dataset=dataset, predictions=predictions, output_folder=output_folder, **extra_args)
diff --git a/maskrcnn_benchmark/engine/predictor.py b/maskrcnn_benchmark/engine/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4362b228f02b4268cd712b82b488f5551511ece
--- /dev/null
+++ b/maskrcnn_benchmark/engine/predictor.py
@@ -0,0 +1,568 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import cv2
+import torch
+import numpy as np
+from torchvision import transforms as T
+
+from maskrcnn_benchmark.modeling.detector import build_detection_model
+from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
+from maskrcnn_benchmark.structures.image_list import to_image_list
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
+from maskrcnn_benchmark import layers as L
+from maskrcnn_benchmark.utils import cv2_util
+
+
+import timeit
+
+class COCODemo(object):
+    # COCO categories for pretty print
+    CATEGORIES = [
+        "__background",
+        "person",
+        "bicycle",
+        "car",
+        "motorcycle",
+        "airplane",
+        "bus",
+        "train",
+        "truck",
+        "boat",
+        "traffic light",
+        "fire hydrant",
+        "stop sign",
+        "parking meter",
+        "bench",
+        "bird",
+        "cat",
+        "dog",
+        "horse",
+        "sheep",
+        "cow",
+        "elephant",
+        "bear",
+        "zebra",
+        "giraffe",
+        "backpack",
+        "umbrella",
+        "handbag",
+        "tie",
+        "suitcase",
+        "frisbee",
+        "skis",
+        "snowboard",
+        "sports ball",
+        "kite",
+        "baseball bat",
+        "baseball glove",
+        "skateboard",
+        "surfboard",
+        "tennis racket",
+        "bottle",
+        "wine glass",
+        "cup",
+        "fork",
+        "knife",
+        "spoon",
+        "bowl",
+        "banana",
+        "apple",
+        "sandwich",
+        "orange",
+        "broccoli",
+        "carrot",
+        "hot dog",
+        "pizza",
+        "donut",
+        "cake",
+        "chair",
+        "couch",
+        "potted plant",
+        "bed",
+        "dining table",
+        "toilet",
+        "tv",
+        "laptop",
+        "mouse",
+        "remote",
+        "keyboard",
+        "cell phone",
+        "microwave",
+        "oven",
+        "toaster",
+        "sink",
+        "refrigerator",
+        "book",
+        "clock",
+        "vase",
+        "scissors",
+        "teddy bear",
+        "hair drier",
+        "toothbrush",
+    ]
+
+    def __init__(
+        self,
+        cfg,
+        confidence_threshold=0.7,
+        show_mask_heatmaps=False,
+        masks_per_dim=2,
+        min_image_size=None,
+        exclude_region=None,
+    ):
+        self.cfg = cfg.clone()
+        self.model = build_detection_model(cfg)
+        self.model.eval()
+        self.device = torch.device(cfg.MODEL.DEVICE)
+        self.model.to(self.device)
+        self.min_image_size = min_image_size
+
+        save_dir = cfg.OUTPUT_DIR
+        checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=save_dir)
+        _ = checkpointer.load(cfg.MODEL.WEIGHT)
+
+        self.transforms = self.build_transform()
+
+        mask_threshold = -1 if show_mask_heatmaps else 0.5
+        self.masker = Masker(threshold=mask_threshold, padding=1)
+
+        # used to make colors for each class
+        self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
+
+        self.cpu_device = torch.device("cpu")
+        self.confidence_threshold = confidence_threshold
+        self.show_mask_heatmaps = show_mask_heatmaps
+        self.masks_per_dim = masks_per_dim
+        self.exclude_region = exclude_region
+
+    def build_transform(self):
+        """
+        Creates a basic transformation that was used to train the models
+        """
+        cfg = self.cfg
+
+        # we are loading images with OpenCV, so we don't need to convert them
+        # to BGR, they are already! So all we need to do is to normalize
+        # by 255 if we want to convert to BGR255 format, or flip the channels
+        # if we want it to be in RGB in [0-1] range.
+        if cfg.INPUT.TO_BGR255:
+            to_bgr_transform = T.Lambda(lambda x: x * 255)
+        else:
+            to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])
+
+        normalize_transform = T.Normalize(
+            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
+        )
+
+        transform = T.Compose(
+            [
+                T.ToPILImage(),
+                T.Resize(self.min_image_size) if self.min_image_size is not None else lambda x:x,
+                T.ToTensor(),
+                to_bgr_transform,
+                normalize_transform,
+            ]
+        )
+        return transform
+
+    def inference(self, image, debug=False):
+        """
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+
+        Returns:
+            prediction (BoxList): the detected objects. Additional information
+                of the detection properties can be found in the fields of
+                the BoxList via `prediction.fields()`
+        """
+        predictions, debug_info = self.compute_prediction(image)
+        top_predictions = self.select_top_predictions(predictions)
+
+        if debug:
+            return top_predictions, debug_info
+        else:
+            return top_predictions
+
+    def run_on_opencv_image(self, image):
+        """
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+
+        Returns:
+            prediction (BoxList): the detected objects. Additional information
+                of the detection properties can be found in the fields of
+                the BoxList via `prediction.fields()`
+        """
+        predictions, debug_info = self.compute_prediction(image)
+        top_predictions = self.select_top_predictions(predictions)
+
+        result = image.copy()
+        if self.show_mask_heatmaps:
+            return self.create_mask_montage(result, top_predictions)
+        result = self.overlay_boxes(result, top_predictions)
+        if self.cfg.MODEL.MASK_ON:
+            result = self.overlay_mask(result, top_predictions)
+        if self.cfg.MODEL.KEYPOINT_ON:
+            result = self.overlay_keypoints(result, top_predictions)
+        result = self.overlay_class_names(result, top_predictions)
+
+        return result, debug_info, top_predictions
+
+    def compute_prediction(self, original_image):
+        """
+        Arguments:
+            original_image (np.ndarray): an image as returned by OpenCV
+
+        Returns:
+            prediction (BoxList): the detected objects. Additional information
+                of the detection properties can be found in the fields of
+                the BoxList via `prediction.fields()`
+        """
+        # apply pre-processing to image
+        # if self.exclude_region:
+        #     for region in self.exclude_region:
+        #         original_image[region[1]:region[3], region[0]:region[2], :] = 255
+        image = self.transforms(original_image)
+
+
+        # convert to an ImageList, padded so that it is divisible by
+        # cfg.DATALOADER.SIZE_DIVISIBILITY
+        image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
+        image_list = image_list.to(self.device)
+        tic = timeit.time.perf_counter()
+
+        # compute predictions
+        with torch.no_grad():
+            predictions, debug_info = self.model(image_list)
+        predictions = [o.to(self.cpu_device) for o in predictions]
+        debug_info['total_time'] = timeit.time.perf_counter() - tic
+
+        # always single image is passed at a time
+        prediction = predictions[0]
+
+        # reshape prediction (a BoxList) into the original image size
+        height, width = original_image.shape[:-1]
+        prediction = prediction.resize((width, height))
+
+        if prediction.has_field("mask"):
+            # if we have masks, paste the masks in the right position
+            # in the image, as defined by the bounding boxes
+            masks = prediction.get_field("mask")
+            # always single image is passed at a time
+            masks = self.masker([masks], [prediction])[0]
+            prediction.add_field("mask", masks)
+
+        return prediction, debug_info
+
+    def select_top_predictions(self, predictions):
+        """
+        Select only predictions which have a `score` > self.confidence_threshold,
+        and returns the predictions in descending order of score
+
+        Arguments:
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `scores`.
+
+        Returns:
+            prediction (BoxList): the detected objects. Additional information
+                of the detection properties can be found in the fields of
+                the BoxList via `prediction.fields()`
+        """
+
+        scores = predictions.get_field("scores")
+        labels = predictions.get_field("labels").tolist()
+        thresh = scores.clone()
+        for i,lb in enumerate(labels):
+            if isinstance(self.confidence_threshold, float):
+                thresh[i] = self.confidence_threshold
+            elif len(self.confidence_threshold)==1:
+                thresh[i] = self.confidence_threshold[0]
+            else:
+                thresh[i] = self.confidence_threshold[lb-1]
+        keep = torch.nonzero(scores > thresh).squeeze(1)
+        predictions = predictions[keep]
+
+        if self.exclude_region:
+            exlude = BoxList(self.exclude_region, predictions.size)
+            iou = boxlist_iou(exlude, predictions)
+            keep = torch.nonzero(torch.sum(iou>0.5, dim=0)==0).squeeze(1)
+            if len(keep)>0:
+                predictions = predictions[keep]
+
+        scores = predictions.get_field("scores")
+        _, idx = scores.sort(0, descending=True)
+        return predictions[idx]
+
+    def compute_colors_for_labels(self, labels):
+        """
+        Simple function that adds fixed colors depending on the class
+        """
+        colors = (30*(labels[:, None] -1)+1)*self.palette
+        colors = (colors % 255).numpy().astype("uint8")
+        return colors
+
+    def overlay_boxes(self, image, predictions):
+        """
+        Adds the predicted boxes on top of the image
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `labels`.
+        """
+        labels = predictions.get_field("labels")
+        boxes = predictions.bbox
+
+        colors = self.compute_colors_for_labels(labels).tolist()
+
+        for box, color in zip(boxes, colors):
+            box = box.to(torch.int64)
+            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
+            image = cv2.rectangle(
+                image, tuple(top_left), tuple(bottom_right), tuple(color), 2)
+
+        return image
+
+    def overlay_scores(self, image, predictions):
+        """
+        Adds the predicted boxes on top of the image
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `labels`.
+        """
+        scores = predictions.get_field("scores")
+        boxes = predictions.bbox
+
+        for box, score in zip(boxes, scores):
+            box = box.to(torch.int64)
+            image = cv2.putText(image, '%.3f'%score,
+                                (box[0], (box[1]+box[3])/2),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                                (255,255,255), 1)
+
+        return image
+
+    def overlay_cboxes(self, image, predictions):
+        """
+        Adds the predicted boxes on top of the image
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `labels`.
+        """
+        scores = predictions.get_field("scores")
+        boxes = predictions.bbox
+        for box, score in zip(boxes, scores):
+            box = box.to(torch.int64)
+            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
+            image = cv2.rectangle(
+                image, tuple(top_left), tuple(bottom_right), (255,0,0), 2)
+            image = cv2.putText(image, '%.3f'%score,
+                                (box[0], (box[1]+box[3])/2),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                                (255,0,0), 1)
+        return image
+
+    def overlay_centers(self, image, predictions):
+        """
+        Adds the predicted boxes on top of the image
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `labels`.
+        """
+        centers = predictions.get_field("centers")
+
+        for cord in centers:
+            cord = cord.to(torch.int64)
+            image = cv2.circle(image, (cord[0].item(),cord[1].item()),
+                               2, (255,0,0), 20)
+
+        return image
+
+    def overlay_count(self, image, predictions):
+        """
+        Adds the predicted boxes on top of the image
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `labels`.
+        """
+        if isinstance(predictions, int):
+            count = predictions
+        else:
+            count = len(predictions)
+        image = cv2.putText(image, 'Count: %d'%count, (0,100), cv2.FONT_HERSHEY_SIMPLEX, 3,  (255,0,0), 3)
+
+        return image
+
+    def overlay_mask(self, image, predictions):
+        """
+        Adds the instances contours for each predicted object.
+        Each label has a different color.
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `mask` and `labels`.
+        """
+        masks = predictions.get_field("mask").numpy()
+        labels = predictions.get_field("labels")
+
+        colors = self.compute_colors_for_labels(labels).tolist()
+
+        for mask, color in zip(masks, colors):
+            thresh = mask[0, :, :, None].astype(np.uint8)
+            contours, hierarchy = cv2_util.findContours(
+                thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+            )
+            image = cv2.drawContours(image, contours, -1, color, 3)
+
+        composite = image
+
+        return composite
+
+    def overlay_keypoints(self, image, predictions):
+        keypoints = predictions.get_field("keypoints")
+        kps = keypoints.keypoints
+        scores = keypoints.get_field("logits")
+        kps = torch.cat((kps[:, :, 0:2], scores[:, :, None]), dim=2).numpy()
+        for region in kps:
+            image = vis_keypoints(image, region.transpose((1, 0)),
+                                  names=keypoints.NAMES, connections=keypoints.CONNECTIONS)
+        return image
+
+    def create_mask_montage(self, image, predictions):
+        """
+        Create a montage showing the probability heatmaps for each one one of the
+        detected objects
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `mask`.
+        """
+        masks = predictions.get_field("mask")
+        masks_per_dim = self.masks_per_dim
+        masks = L.interpolate(
+            masks.float(), scale_factor=1 / masks_per_dim
+        ).byte()
+        height, width = masks.shape[-2:]
+        max_masks = masks_per_dim ** 2
+        masks = masks[:max_masks]
+        # handle case where we have less detections than max_masks
+        if len(masks) < max_masks:
+            masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8)
+            masks_padded[: len(masks)] = masks
+            masks = masks_padded
+        masks = masks.reshape(masks_per_dim, masks_per_dim, height, width)
+        result = torch.zeros(
+            (masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8
+        )
+        for y in range(masks_per_dim):
+            start_y = y * height
+            end_y = (y + 1) * height
+            for x in range(masks_per_dim):
+                start_x = x * width
+                end_x = (x + 1) * width
+                result[start_y:end_y, start_x:end_x] = masks[y, x]
+        return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET)
+
+    def overlay_class_names(self, image, predictions, names=None):
+        """
+        Adds detected class names and scores in the positions defined by the
+        top-left corner of the predicted bounding box
+
+        Arguments:
+            image (np.ndarray): an image as returned by OpenCV
+            predictions (BoxList): the result of the computation by the model.
+                It should contain the field `scores` and `labels`.
+        """
+        scores = predictions.get_field("scores").tolist()
+        labels = predictions.get_field("labels").tolist()
+        if names:
+            labels = [names[i-1] for i in labels]
+        else:
+            labels = [self.CATEGORIES[i] for i in labels]
+        boxes = predictions.bbox
+
+        template = "{}: {:.2f}"
+        for box, score, label in zip(boxes, scores, labels):
+            x, y = box[:2]
+            s = template.format(label, score)
+            cv2.putText(
+                image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
+            )
+
+        return image
+
+def vis_keypoints(img, kps, kp_thresh=0, alpha=0.7, names=None, connections=None):
+    """Visualizes keypoints (adapted from vis_one_image).
+    kps has shape (4, #keypoints) where 4 rows are (x, y, logit, prob).
+    """
+
+    dataset_keypoints = names
+    kp_lines = connections
+
+    # simple rainbow color map implementation
+    blue_red_ratio = 0.8
+    gx = lambda x: (6-2*blue_red_ratio)*x + blue_red_ratio
+    colors = [[256*max(0, (3-abs(gx(i)-4)-abs(gx(i)-5))/2),
+               256*max(0, (3-abs(gx(i)-2)-abs(gx(i)-4))/2),
+               256*max(0, (3-abs(gx(i)-1)-abs(gx(i)-2))/2),] for i in np.linspace(0, 1, len(kp_lines) + 2)]
+
+    # Perform the drawing on a copy of the image, to allow for blending.
+    kp_mask = np.copy(img)
+
+    # Draw mid shoulder / mid hip first for better visualization.
+    mid_shoulder = (
+        kps[:2, dataset_keypoints.index('right_shoulder')] +
+        kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
+    sc_mid_shoulder = np.minimum(
+        kps[2, dataset_keypoints.index('right_shoulder')],
+        kps[2, dataset_keypoints.index('left_shoulder')])
+    nose_idx = dataset_keypoints.index('nose')
+    if sc_mid_shoulder > kp_thresh and kps[2, nose_idx] > kp_thresh:
+        cv2.line(
+            kp_mask, tuple(mid_shoulder), tuple(kps[:2, nose_idx]),
+            color=colors[len(kp_lines)], thickness=2, lineType=cv2.LINE_AA)
+
+    if 'right_hip' in names and 'left_hip' in names:
+        mid_hip = (
+            kps[:2, dataset_keypoints.index('right_hip')] +
+            kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
+        sc_mid_hip = np.minimum(
+            kps[2, dataset_keypoints.index('right_hip')],
+            kps[2, dataset_keypoints.index('left_hip')])
+        if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
+            cv2.line(
+                kp_mask, tuple(mid_shoulder), tuple(mid_hip),
+                color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA)
+
+    # Draw the keypoints.
+    for l in range(len(kp_lines)):
+        i1 = kp_lines[l][0]
+        i2 = kp_lines[l][1]
+        p1 = kps[0, i1], kps[1, i1]
+        p2 = kps[0, i2], kps[1, i2]
+        if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
+            cv2.line(
+                kp_mask, p1, p2,
+                color=colors[l], thickness=2, lineType=cv2.LINE_AA)
+        if kps[2, i1] > kp_thresh:
+            cv2.circle(
+                kp_mask, p1,
+                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
+        if kps[2, i2] > kp_thresh:
+            cv2.circle(
+                kp_mask, p2,
+                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
+
+    # Blend the keypoints.
+    return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/engine/predictor_glip.py b/maskrcnn_benchmark/engine/predictor_glip.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbdfcc24b6abf711a77217d03c83bad7d6c6f442
--- /dev/null
+++ b/maskrcnn_benchmark/engine/predictor_glip.py
@@ -0,0 +1,471 @@
+import cv2
+import torch
+import re
+import numpy as np
+from typing import List, Union
+import nltk
+import inflect
+from transformers import AutoTokenizer
+from torchvision import transforms as T
+import pdb
+from maskrcnn_benchmark.modeling.detector import build_detection_model
+from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
+from maskrcnn_benchmark.structures.image_list import to_image_list
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark import layers as L
+from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
+from maskrcnn_benchmark.utils import cv2_util
+
+engine = inflect.engine()
+nltk.download('punkt')
+nltk.download('averaged_perceptron_tagger')
+
+import timeit
+
+
+class GLIPDemo(object):
+    def __init__(self,
+                 cfg,
+                 confidence_threshold=0.7,
+                 min_image_size=None,
+                 show_mask_heatmaps=False,
+                 masks_per_dim=5,
+                 load_model=True
+                 ):
+        self.cfg = cfg.clone()
+        if load_model:
+            self.model = build_detection_model(cfg)
+            self.model.eval()
+            self.device = torch.device(cfg.MODEL.DEVICE)
+            self.model.to(self.device)
+        self.min_image_size = min_image_size
+        self.show_mask_heatmaps = show_mask_heatmaps
+        self.masks_per_dim = masks_per_dim
+
+        save_dir = cfg.OUTPUT_DIR
+        if load_model:
+            checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=save_dir)
+            _ = checkpointer.load(cfg.MODEL.WEIGHT)
+
+        self.transforms = self.build_transform()
+
+        # used to make colors for each tokens
+        mask_threshold = -1 if show_mask_heatmaps else 0.5
+        self.masker = Masker(threshold=mask_threshold, padding=1)
+        self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
+        self.cpu_device = torch.device("cpu")
+        self.confidence_threshold = confidence_threshold
+
+        self.tokenizer = self.build_tokenizer()
+
+    def build_transform(self):
+        """
+        Creates a basic transformation that was used to train the models
+        """
+        cfg = self.cfg
+
+        # we are loading images with OpenCV, so we don't need to convert them
+        # to BGR, they are already! So all we need to do is to normalize
+        # by 255 if we want to convert to BGR255 format, or flip the channels
+        # if we want it to be in RGB in [0-1] range.
+        if cfg.INPUT.TO_BGR255:
+            to_bgr_transform = T.Lambda(lambda x: x * 255)
+        else:
+            to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])
+
+        normalize_transform = T.Normalize(
+            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
+        )
+
+        transform = T.Compose(
+            [
+                T.ToPILImage(),
+                T.Resize(self.min_image_size) if self.min_image_size is not None else lambda x: x,
+                T.ToTensor(),
+                to_bgr_transform,
+                normalize_transform,
+            ]
+        )
+        return transform
+
+    def build_tokenizer(self):
+        cfg = self.cfg
+        tokenizer = None
+        if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
+            from transformers import CLIPTokenizerFast
+            if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
+                tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                              from_slow=True, mask_token='ðŁĴij</w>')
+            else:
+                tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                              from_slow=True)
+        return tokenizer
+
+    def run_ner(self, caption):
+        noun_phrases = find_noun_phrases(caption)
+        noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
+        noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
+        relevant_phrases = noun_phrases
+        labels = noun_phrases
+        self.entities = labels
+
+        tokens_positive = []
+
+        for entity, label in zip(relevant_phrases, labels):
+            try:
+                # search all occurrences and mark them as different entities
+                for m in re.finditer(entity, caption.lower()):
+                    tokens_positive.append([[m.start(), m.end()]])
+            except:
+                print("noun entities:", noun_phrases)
+                print("entity:", entity)
+                print("caption:", caption.lower())
+
+        return tokens_positive
+
+    def inference(self, original_image, original_caption):
+        predictions = self.compute_prediction(original_image, original_caption)
+        top_predictions = self._post_process_fixed_thresh(predictions)
+        return top_predictions
+
+    def run_on_web_image(self,
+                         original_image,
+                         original_caption,
+                         thresh=0.5,
+                         custom_entity=None,
+                         alpha=0.0):
+        predictions = self.compute_prediction(original_image, original_caption, custom_entity)
+        top_predictions = self._post_process(predictions, thresh)
+
+        result = original_image.copy()
+        if self.show_mask_heatmaps:
+            return self.create_mask_montage(result, top_predictions)
+        result = self.overlay_boxes(result, top_predictions)
+        result = self.overlay_entity_names(result, top_predictions)
+        if self.cfg.MODEL.MASK_ON:
+            result = self.overlay_mask(result, top_predictions)
+        return result, top_predictions
+
+    def visualize_with_predictions(self,
+                                   original_image,
+                                   predictions,
+                                   thresh=0.5,
+                                   alpha=0.0,
+                                   box_pixel=3,
+                                   text_size=1,
+                                   text_pixel=2,
+                                   text_offset=10,
+                                   text_offset_original=4,
+                                   color=255):
+        self.color = color
+        height, width = original_image.shape[:-1]
+        predictions = predictions.resize((width, height))
+        top_predictions = self._post_process(predictions, thresh)
+
+        result = original_image.copy()
+        if self.show_mask_heatmaps:
+            return self.create_mask_montage(result, top_predictions)
+        result = self.overlay_boxes(result, top_predictions, alpha=alpha, box_pixel=box_pixel)
+        result = self.overlay_entity_names(result, top_predictions, text_size=text_size, text_pixel=text_pixel,
+                                           text_offset=text_offset, text_offset_original=text_offset_original)
+        if self.cfg.MODEL.MASK_ON:
+            result = self.overlay_mask(result, top_predictions)
+        return result, top_predictions
+
+    def compute_prediction(self, original_image, original_caption, custom_entity=None):
+        # image
+        image = self.transforms(original_image)
+        image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
+        image_list = image_list.to(self.device)
+        # caption
+        if isinstance(original_caption, list):
+            # we directly provided a list of category names
+            caption_string = ""
+            tokens_positive = []
+            seperation_tokens = " . "
+            for word in original_caption:
+                tokens_positive.append([len(caption_string), len(caption_string) + len(word)])
+                caption_string += word
+                caption_string += seperation_tokens
+
+            tokenized = self.tokenizer([caption_string], return_tensors="pt")
+            tokens_positive = [tokens_positive]
+
+            original_caption = caption_string
+            print(tokens_positive)
+        else:
+            tokenized = self.tokenizer([original_caption], return_tensors="pt")
+            if custom_entity is None:
+                tokens_positive = self.run_ner(original_caption)
+            print(tokens_positive)
+        # process positive map
+        positive_map = create_positive_map(tokenized, tokens_positive)
+
+        if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
+            plus = 1
+        else:
+            plus = 0
+
+        positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus)
+        self.plus = plus
+        self.positive_map_label_to_token = positive_map_label_to_token
+        tic = timeit.time.perf_counter()
+
+        # compute predictions
+        with torch.no_grad():
+            predictions = self.model(image_list, captions=[original_caption], positive_map=positive_map_label_to_token)
+            predictions = [o.to(self.cpu_device) for o in predictions]
+        print("inference time per image: {}".format(timeit.time.perf_counter() - tic))
+
+        # always single image is passed at a time
+        prediction = predictions[0]
+
+        # reshape prediction (a BoxList) into the original image size
+        height, width = original_image.shape[:-1]
+        prediction = prediction.resize((width, height))
+
+        if prediction.has_field("mask"):
+            # if we have masks, paste the masks in the right position
+            # in the image, as defined by the bounding boxes
+            masks = prediction.get_field("mask")
+            # always single image is passed at a time
+            masks = self.masker([masks], [prediction])[0]
+            prediction.add_field("mask", masks)
+
+        return prediction
+
+    def _post_process_fixed_thresh(self, predictions):
+        scores = predictions.get_field("scores")
+        labels = predictions.get_field("labels").tolist()
+        thresh = scores.clone()
+        for i, lb in enumerate(labels):
+            if isinstance(self.confidence_threshold, float):
+                thresh[i] = self.confidence_threshold
+            elif len(self.confidence_threshold) == 1:
+                thresh[i] = self.confidence_threshold[0]
+            else:
+                thresh[i] = self.confidence_threshold[lb - 1]
+        keep = torch.nonzero(scores > thresh).squeeze(1)
+        predictions = predictions[keep]
+
+        scores = predictions.get_field("scores")
+        _, idx = scores.sort(0, descending=True)
+        return predictions[idx]
+
+    def _post_process(self, predictions, threshold=0.5):
+        scores = predictions.get_field("scores")
+        labels = predictions.get_field("labels").tolist()
+        thresh = scores.clone()
+        for i, lb in enumerate(labels):
+            if isinstance(self.confidence_threshold, float):
+                thresh[i] = threshold
+            elif len(self.confidence_threshold) == 1:
+                thresh[i] = threshold
+            else:
+                thresh[i] = self.confidence_threshold[lb - 1]
+        keep = torch.nonzero(scores > thresh).squeeze(1)
+        predictions = predictions[keep]
+
+        scores = predictions.get_field("scores")
+        _, idx = scores.sort(0, descending=True)
+        return predictions[idx]
+
+    def compute_colors_for_labels(self, labels):
+        """
+        Simple function that adds fixed colors depending on the class
+        """
+        colors = (300 * (labels[:, None] - 1) + 1) * self.palette
+        colors = (colors % 255).numpy().astype("uint8")
+        try:
+            colors = (colors * 0 + self.color).astype("uint8")
+        except:
+            pass
+        return colors
+
+    def overlay_boxes(self, image, predictions, alpha=0.5, box_pixel=3):
+        labels = predictions.get_field("labels")
+        boxes = predictions.bbox
+
+        colors = self.compute_colors_for_labels(labels).tolist()
+        new_image = image.copy()
+        for box, color in zip(boxes, colors):
+            box = box.to(torch.int64)
+            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
+            new_image = cv2.rectangle(
+                new_image, tuple(top_left), tuple(bottom_right), tuple(color), box_pixel)
+
+        # Following line overlays transparent rectangle over the image
+        image = cv2.addWeighted(new_image, alpha, image, 1 - alpha, 0)
+
+        return image
+
+    def overlay_scores(self, image, predictions):
+        scores = predictions.get_field("scores")
+        boxes = predictions.bbox
+
+        for box, score in zip(boxes, scores):
+            box = box.to(torch.int64)
+            image = cv2.putText(image, '%.3f' % score,
+                                (int(box[0]), int((box[1] + box[3]) / 2)),
+                                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)
+
+        return image
+
+    def overlay_entity_names(self, image, predictions, names=None, text_size=0.7, text_pixel=2, text_offset=10,
+                             text_offset_original=4):
+        scores = predictions.get_field("scores").tolist()
+        labels = predictions.get_field("labels").tolist()
+        new_labels = []
+        if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
+            plus = 1
+        else:
+            plus = 0
+        self.plus = plus
+        if self.entities and self.plus:
+            for i in labels:
+                if i <= len(self.entities):
+                    new_labels.append(self.entities[i - self.plus])
+                else:
+                    new_labels.append('object')
+            # labels = [self.entities[i - self.plus] for i in labels ]
+        else:
+            new_labels = ['object' for i in labels]
+        boxes = predictions.bbox
+
+        template = "{}:{:.2f}"
+        previous_locations = []
+        for box, score, label in zip(boxes, scores, new_labels):
+            x, y = box[:2]
+            s = template.format(label, score).replace("_", " ").replace("(", "").replace(")", "")
+            for x_prev, y_prev in previous_locations:
+                if abs(x - x_prev) < abs(text_offset) and abs(y - y_prev) < abs(text_offset):
+                    y -= text_offset
+
+            cv2.putText(
+                image, s, (int(x), int(y) - text_offset_original), cv2.FONT_HERSHEY_SIMPLEX, text_size,
+                (255, 255, 255), text_pixel, cv2.LINE_AA
+            )
+            previous_locations.append((int(x), int(y)))
+
+        return image
+
+    def overlay_mask(self, image, predictions):
+        masks = predictions.get_field("mask").numpy()
+        labels = predictions.get_field("labels")
+
+        colors = self.compute_colors_for_labels(labels).tolist()
+
+        # import pdb
+        # pdb.set_trace()
+        # masks = masks > 0.1
+
+        for mask, color in zip(masks, colors):
+            thresh = mask[0, :, :, None].astype(np.uint8)
+            contours, hierarchy = cv2_util.findContours(
+                thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+            )
+            image = cv2.drawContours(image, contours, -1, color, 2)
+
+        composite = image
+
+        return composite
+
+    def create_mask_montage(self, image, predictions):
+        masks = predictions.get_field("mask")
+        masks_per_dim = self.masks_per_dim
+        masks = L.interpolate(
+            masks.float(), scale_factor=1 / masks_per_dim
+        ).byte()
+        height, width = masks.shape[-2:]
+        max_masks = masks_per_dim ** 2
+        masks = masks[:max_masks]
+        # handle case where we have less detections than max_masks
+        if len(masks) < max_masks:
+            masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8)
+            masks_padded[: len(masks)] = masks
+            masks = masks_padded
+        masks = masks.reshape(masks_per_dim, masks_per_dim, height, width)
+        result = torch.zeros(
+            (masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8
+        )
+        for y in range(masks_per_dim):
+            start_y = y * height
+            end_y = (y + 1) * height
+            for x in range(masks_per_dim):
+                start_x = x * width
+                end_x = (x + 1) * width
+                result[start_y:end_y, start_x:end_x] = masks[y, x]
+
+        return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET), None
+
+
+def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0):
+    positive_map_label_to_token = {}
+    for i in range(len(positive_map)):
+        positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist()
+    return positive_map_label_to_token
+
+
+def create_positive_map(tokenized, tokens_positive):
+    """construct a map such that positive_map[i,j] = True iff box i is associated to token j"""
+    positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)
+
+    for j, tok_list in enumerate(tokens_positive):
+        for (beg, end) in tok_list:
+            try:
+                beg_pos = tokenized.char_to_token(beg)
+                end_pos = tokenized.char_to_token(end - 1)
+            except Exception as e:
+                print("beg:", beg, "end:", end)
+                print("token_positive:", tokens_positive)
+                # print("beg_pos:", beg_pos, "end_pos:", end_pos)
+                raise e
+            if beg_pos is None:
+                try:
+                    beg_pos = tokenized.char_to_token(beg + 1)
+                    if beg_pos is None:
+                        beg_pos = tokenized.char_to_token(beg + 2)
+                except:
+                    beg_pos = None
+            if end_pos is None:
+                try:
+                    end_pos = tokenized.char_to_token(end - 2)
+                    if end_pos is None:
+                        end_pos = tokenized.char_to_token(end - 3)
+                except:
+                    end_pos = None
+            if beg_pos is None or end_pos is None:
+                continue
+
+            assert beg_pos is not None and end_pos is not None
+            positive_map[j, beg_pos: end_pos + 1].fill_(1)
+    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
+
+
+def find_noun_phrases(caption: str) -> List[str]:
+    caption = caption.lower()
+    tokens = nltk.word_tokenize(caption)
+    pos_tags = nltk.pos_tag(tokens)
+
+    grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}"
+    cp = nltk.RegexpParser(grammar)
+    result = cp.parse(pos_tags)
+
+    noun_phrases = list()
+    for subtree in result.subtrees():
+        if subtree.label() == 'NP':
+            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))
+
+    return noun_phrases
+
+
+def remove_punctuation(text: str) -> str:
+    punct = ['|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^',
+             '\'', '\"', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
+             ]
+    for p in punct:
+        text = text.replace(p, '')
+    return text.strip()
diff --git a/maskrcnn_benchmark/engine/singlepath_trainer.py b/maskrcnn_benchmark/engine/singlepath_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..303f151d11e5d2a4ab7c849cfb213e175f3469c4
--- /dev/null
+++ b/maskrcnn_benchmark/engine/singlepath_trainer.py
@@ -0,0 +1,141 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import datetime
+import logging
+import time
+import random
+import torch
+import torch.distributed as dist
+from maskrcnn_benchmark.utils.comm import get_world_size, synchronize, broadcast_data
+from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+from maskrcnn_benchmark.utils.ema import ModelEma
+
+
+def reduce_loss_dict(loss_dict):
+    """
+    Reduce the loss dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    loss_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return loss_dict
+    with torch.no_grad():
+        loss_names = []
+        all_losses = []
+        for k in sorted(loss_dict.keys()):
+            loss_names.append(k)
+            all_losses.append(loss_dict[k])
+        all_losses = torch.stack(all_losses, dim=0)
+        dist.reduce(all_losses, dst=0)
+        if dist.get_rank() == 0:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            all_losses /= world_size
+        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
+    return reduced_losses
+
+
+def do_train(
+        cfg,
+        model,
+        data_loader,
+        optimizer,
+        scheduler,
+        checkpointer,
+        device,
+        checkpoint_period,
+        arguments,
+        rngs=None
+):
+    logger = logging.getLogger("maskrcnn_benchmark.trainer")
+    logger.info("Start training")
+    meters = MetricLogger(delimiter="  ")
+    max_iter = len(data_loader)
+    start_iter = arguments["iteration"]
+    model.train()
+    model_ema = None
+    if cfg.SOLVER.MODEL_EMA>0:
+        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)
+    start_training_time = time.time()
+    end = time.time()
+
+    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
+
+        if any(len(target) < 1 for target in targets):
+            logger.error("Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}" )
+            continue
+        data_time = time.time() - end
+        iteration = iteration + 1
+        arguments["iteration"] = iteration
+
+        images = images.to(device)
+        targets = [target.to(device) for target in targets]
+
+        # synchronize rngs
+        if rngs is None:
+            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                mix_nums = model.module.mix_nums
+            else:
+                mix_nums = model.mix_nums
+            rngs = [random.randint(0, mix-1) for mix in mix_nums]
+        rngs = broadcast_data(rngs)
+
+        for param in model.parameters():
+            param.requires_grad = False
+        loss_dict = model(images, targets, rngs)
+
+        losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = reduce_loss_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+        meters.update(loss=losses_reduced, **loss_dict_reduced)
+
+        optimizer.zero_grad()
+        losses.backward()
+        optimizer.step()
+        scheduler.step()
+
+        if model_ema is not None:
+            model_ema.update(model)
+            arguments["model_ema"] = model_ema.state_dict()
+
+        batch_time = time.time() - end
+        end = time.time()
+        meters.update(time=batch_time, data=data_time)
+
+        eta_seconds = meters.time.global_avg * (max_iter - iteration)
+        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+        if iteration % 20 == 0 or iteration == max_iter:
+            logger.info(
+                meters.delimiter.join(
+                    [
+                        "eta: {eta}",
+                        "iter: {iter}",
+                        "{meters}",
+                        "lr: {lr:.6f}",
+                        "max mem: {memory:.0f}",
+                    ]
+                ).format(
+                    eta=eta_string,
+                    iter=iteration,
+                    meters=str(meters),
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+                )
+            )
+        if iteration % checkpoint_period == 0:
+            checkpointer.save("model_{:07d}".format(iteration), **arguments)
+        if iteration == max_iter:
+            if model_ema is not None:
+                model.load_state_dict(model_ema.state_dict())
+            checkpointer.save("model_final", **arguments)
+
+    total_training_time = time.time() - start_training_time
+    total_time_str = str(datetime.timedelta(seconds=total_training_time))
+    logger.info(
+        "Total training time: {} ({:.4f} s / it)".format(
+            total_time_str, total_training_time / (max_iter)
+        )
+    )
diff --git a/maskrcnn_benchmark/engine/stage_trainer.py b/maskrcnn_benchmark/engine/stage_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c925d48bb8fae7ac76afd18bc5ea23a9491827c
--- /dev/null
+++ b/maskrcnn_benchmark/engine/stage_trainer.py
@@ -0,0 +1,184 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import datetime
+import logging
+import time
+
+import torch
+import torch.distributed as dist
+
+from maskrcnn_benchmark.utils.comm import get_world_size
+from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+
+
+def reduce_loss_dict(all_loss_dict):
+    """
+    Reduce the loss dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    loss_dict, after reduction.
+    """
+    world_size = get_world_size()
+    with torch.no_grad():
+        loss_names = []
+        all_losses = []
+        for loss_dict in all_loss_dict:
+            for k in sorted(loss_dict.keys()):
+                loss_names.append(k)
+                all_losses.append(loss_dict[k])
+        all_losses = torch.stack(all_losses, dim=0)
+        if world_size > 1:
+            dist.reduce(all_losses, dst=0)
+            if dist.get_rank() == 0:
+                # only main process gets accumulated, so only divide by
+                # world_size in this case
+                all_losses /= world_size
+
+        reduced_losses = {}
+        for k, v in zip(loss_names, all_losses):
+            if k not in reduced_losses:
+                reduced_losses[k] = v / len(all_loss_dict)
+            reduced_losses[k] += v / len(all_loss_dict)
+
+    return reduced_losses
+
+
+def do_train(
+        model,
+        data_loader,
+        optimizer,
+        scheduler,
+        checkpointer,
+        device,
+        checkpoint_period,
+        arguments,
+):
+    logger = logging.getLogger("maskrcnn_benchmark.trainer")
+    logger.info("Start training")
+    meters = MetricLogger(delimiter="  ")
+    epoch_per_stage = arguments['epoch_per_stage']
+    max_iter = sum(len(stage_loader) * epoch_per_stage[si] for si, stage_loader in enumerate(data_loader))
+    max_iter += epoch_per_stage[-1] * min(len(stage_loader) for stage_loader in data_loader)
+    model.train()
+    start_training_time = time.time()
+    end = time.time()
+
+    for stage_i, stage_loader in enumerate(data_loader):
+        for ep in range(epoch_per_stage[stage_i]):
+            start_iter = arguments["iteration"]
+            for iteration, (images, targets, _) in enumerate(stage_loader, start_iter):
+                data_time = time.time() - end
+                iteration = iteration + 1
+                arguments["iteration"] = iteration
+
+                scheduler[stage_i].step()
+
+                all_stage_loss_dict = []
+                images = images.to(device)
+                targets = [target.to(device) for target in targets]
+                loss_dict = model(images, targets, stage_i)
+                all_stage_loss_dict.append(loss_dict)
+
+                losses = sum(loss for loss_dict in all_stage_loss_dict for loss in loss_dict.values())
+
+                # reduce losses over all GPUs for logging purposes
+                loss_dict_reduced = reduce_loss_dict(all_stage_loss_dict)
+                losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+                meters.update(loss=losses_reduced, **loss_dict_reduced)
+
+                optimizer.zero_grad()
+                losses.backward()
+                optimizer.step()
+
+                batch_time = time.time() - end
+                end = time.time()
+                meters.update(time=batch_time, data=data_time)
+
+                eta_seconds = meters.time.global_avg * (max_iter - iteration)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+                if iteration % 20 == 0 or iteration == max_iter:
+                    logger.info(
+                        meters.delimiter.join(
+                            [
+                                "eta: {eta}",
+                                "iter: {iter}",
+                                "{meters}",
+                                "lr: {lr:.6f}",
+                                "max mem: {memory:.0f}",
+                            ]
+                        ).format(
+                            eta=eta_string,
+                            iter=iteration,
+                            meters=str(meters),
+                            lr=optimizer.param_groups[0]["lr"],
+                            memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+                        )
+                    )
+                if iteration % checkpoint_period == 0:
+                    checkpointer.save("model_{:07d}".format(iteration), **arguments)
+                if iteration == max_iter:
+                    checkpointer.save("model_final", **arguments)
+
+    for ep in range(epoch_per_stage[-1]):
+        start_iter = arguments["iteration"]
+        for iteration, stage_loader in enumerate(zip(*data_loader), start_iter):
+            data_time = time.time() - end
+            iteration = iteration + 1
+            arguments["iteration"] = iteration
+
+            scheduler[-1].step()
+
+            all_task_loss_dict = []
+            for stage_i, (images, targets, _) in enumerate(stage_loader):
+                images = images.to(device)
+                targets = [target.to(device) for target in targets]
+                loss_dict = model(images, targets, stage_i)
+                all_task_loss_dict.append(loss_dict)
+
+            losses = sum(loss for loss_dict in all_task_loss_dict for loss in loss_dict.values())
+
+            # reduce losses over all GPUs for logging purposes
+            loss_dict_reduced = reduce_loss_dict(all_task_loss_dict)
+            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+            meters.update(loss=losses_reduced, **loss_dict_reduced)
+
+            optimizer.zero_grad()
+            losses.backward()
+            optimizer.step()
+
+            batch_time = time.time() - end
+            end = time.time()
+            meters.update(time=batch_time, data=data_time)
+
+            eta_seconds = meters.time.global_avg * (max_iter - iteration)
+            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+            if iteration % 20 == 0 or iteration == max_iter:
+                logger.info(
+                    meters.delimiter.join(
+                        [
+                            "eta: {eta}",
+                            "iter: {iter}",
+                            "{meters}",
+                            "lr: {lr:.6f}",
+                            "max mem: {memory:.0f}",
+                        ]
+                    ).format(
+                        eta=eta_string,
+                        iter=iteration,
+                        meters=str(meters),
+                        lr=optimizer.param_groups[0]["lr"],
+                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+                    )
+                )
+            if iteration % checkpoint_period == 0:
+                checkpointer.save("model_{:07d}".format(iteration), **arguments)
+            if iteration == max_iter:
+                checkpointer.save("model_final", **arguments)
+
+    total_training_time = time.time() - start_training_time
+    total_time_str = str(datetime.timedelta(seconds=total_training_time))
+    logger.info(
+        "Total training time: {} ({:.4f} s / it)".format(
+            total_time_str, total_training_time / (max_iter)
+        )
+    )
diff --git a/maskrcnn_benchmark/engine/trainer.py b/maskrcnn_benchmark/engine/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939a0443ba5612303087d750d49d724c890855c
--- /dev/null
+++ b/maskrcnn_benchmark/engine/trainer.py
@@ -0,0 +1,360 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import datetime
+import logging
+import sys
+import os
+import math
+import time
+
+import torch
+import torch.distributed as dist
+
+from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank
+from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+from maskrcnn_benchmark.utils.ema import ModelEma
+from maskrcnn_benchmark.utils.amp import autocast, GradScaler
+from maskrcnn_benchmark.data.datasets.evaluation import evaluate
+from .inference import inference
+import pdb
+
+def reduce_loss_dict(loss_dict):
+    """
+    Reduce the loss dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    loss_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return loss_dict
+    with torch.no_grad():
+        loss_names = []
+        all_losses = []
+        for k in sorted(loss_dict.keys()):
+            loss_names.append(k)
+            all_losses.append(loss_dict[k])
+        all_losses = torch.stack(all_losses, dim=0)
+        dist.reduce(all_losses, dst=0)
+        if dist.get_rank() == 0:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            all_losses /= world_size
+        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
+    return reduced_losses
+
+
+def do_train(
+        cfg,
+        model,
+        data_loader,
+        optimizer,
+        scheduler,
+        checkpointer,
+        device,
+        checkpoint_period,
+        arguments,
+        val_data_loader=None,
+        meters=None,
+        zero_shot=False
+):
+    logger = logging.getLogger("maskrcnn_benchmark.trainer")
+    logger.info("Start training")
+    # meters = MetricLogger(delimiter="  ")
+    max_iter = len(data_loader)
+    start_iter = arguments["iteration"]
+    model.train()
+    model_ema = None
+    if cfg.SOLVER.MODEL_EMA > 0:
+        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)
+    start_training_time = time.time()
+    end = time.time()
+
+    if cfg.SOLVER.USE_AMP:
+        scaler = GradScaler()
+
+    global_rank = get_rank()
+
+    if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1:
+        checkpoint_period = len(data_loader) * cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH
+    
+    if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1:
+        print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH )
+
+    if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
+        patience_counter = 0
+        previous_best = 0.0
+
+    # Adapt the weight decay
+    if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
+        milestone_target = 0
+        for i, milstone in enumerate(list(scheduler.milestones)):
+            if scheduler.last_epoch >= milstone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
+                milestone_target = i+1
+    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
+        nnegative = sum(len(target) < 1 for target in targets)
+        nsample = len(targets)
+        if nsample == nnegative or nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH:
+            logger.info('[WARNING] Sampled {} negative in {} in a batch, greater the allowed ratio {}, skip'.
+                        format(nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH))
+            continue
+
+        data_time = time.time() - end
+        iteration = iteration + 1
+        arguments["iteration"] = iteration
+
+        images = images.to(device)
+        captions = None
+        try:
+            targets = [target.to(device) for target in targets]
+            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
+        except:
+            pass
+        # Freeze language backbone
+        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
+            if hasattr(model, "module"):
+                model.module.language_backbone.eval()
+            else:
+                model.language_backbone.eval()
+
+        if cfg.SOLVER.USE_AMP:
+            with autocast():
+                if len(captions) > 0:
+                    loss_dict = model(images, targets, captions, positive_map, greenlight_map = greenlight_map)
+                else:
+                    loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+            # save checkpoints for further debug if nan happens
+            # loss_value = losses.item()
+            # if not math.isfinite(loss_value):
+            #     logging.error(f'=> loss is {loss_value}, stopping training')
+            #     logging.error("Losses are : {}".format(loss_dict))
+            #     time_str = time.strftime('%Y-%m-%d-%H-%M')
+            #     fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth')
+            #     logging.info(f'=> save error state to {fname}')
+            #     dict_to_save = {
+            #         'x': images,
+            #         'y': targets,
+            #         'loss': losses,
+            #         'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
+            #     }
+            #     if len(captions) > 0:
+            #         dict_to_save['captions'] = captions
+            #         dict_to_save['positive_map'] = positive_map
+            #     torch.save(
+            #             dict_to_save,
+            #             fname
+            #         )
+
+
+            if torch.isnan(losses) or torch.isinf(losses):
+                logging.error("NaN encountered, ignoring")
+                losses[losses != losses] = 0
+            optimizer.zero_grad()
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            scheduler.step()
+        else:
+            if len(captions) > 0:
+                loss_dict = model(images, targets, captions, positive_map)
+            else:
+                loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+            # loss_value = losses.item()
+            # if not math.isfinite(loss_value):
+            #     logging.error(f'=> loss is {loss_value}, stopping training')
+            #     time_str = time.strftime('%Y-%m-%d-%H-%M')
+            #     fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth')
+            #     logging.info(f'=> save error state to {fname}')
+            #     dict_to_save = {
+            #         'x': images,
+            #         'y': targets,
+            #         'loss': losses,
+            #         'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
+            #     }
+            #     if len(captions) > 0:
+            #         dict_to_save['captions'] = captions
+            #         dict_to_save['positive_map'] = positive_map
+            #     torch.save(
+            #         dict_to_save,
+            #         fname
+            #     )
+                
+
+            if torch.isnan(losses) or torch.isinf(losses):
+                losses[losses != losses] = 0
+            optimizer.zero_grad()
+            losses.backward()
+            optimizer.step()
+            scheduler.step()
+
+        # Adapt the weight decay: only support multiStepLR
+        if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
+            if milestone_target < len(scheduler.milestones):
+                next_milestone = list(scheduler.milestones)[milestone_target]
+            else:
+                next_milestone = float('inf')
+            if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
+                gamma = scheduler.gamma
+                logger.info("Drop the weight decay by {}!".format(gamma))
+                for param in optimizer.param_groups:
+                    if 'weight_decay' in param:
+                        param['weight_decay'] *= gamma
+                # move the target forward
+                milestone_target += 1
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = reduce_loss_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+        meters.update(loss=losses_reduced, **loss_dict_reduced)
+        if model_ema is not None:
+            model_ema.update(model)
+            arguments["model_ema"] = model_ema.state_dict()
+
+        batch_time = time.time() - end
+        end = time.time()
+        meters.update(time=batch_time, data=data_time)
+        eta_seconds = meters.time.global_avg * (max_iter - iteration)
+        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+        if iteration % 20 == 0 or iteration == max_iter:
+        # if iteration % 1 == 0 or iteration == max_iter:
+            #logger.info(
+            if global_rank <= 0:
+                print(
+                    meters.delimiter.join(
+                        [
+                            "eta: {eta}",
+                            "iter: {iter}",
+                            "{meters}",
+                            "lr: {lr:.6f}",
+                            "wd: {wd:.6f}",
+                            "max mem: {memory:.0f}",
+                        ]
+                    ).format(
+                        eta=eta_string,
+                        iter=iteration,
+                        meters=str(meters),
+                        lr=optimizer.param_groups[0]["lr"],
+                        wd=optimizer.param_groups[0]["weight_decay"],
+                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+                    )
+                )
+        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
+            if is_main_process():
+                print("Evaluating")
+            eval_result = 0.0
+            model.eval()
+            if cfg.SOLVER.TEST_WITH_INFERENCE:
+                with torch.no_grad():
+                    try:
+                        _model = model.module
+                    except:
+                        _model = model
+                    _result = inference(
+                        model = _model,
+                        data_loader = val_data_loader,
+                        dataset_name="val",
+                        device=device,
+                        expected_results=cfg.TEST.EXPECTED_RESULTS,
+                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
+                        output_folder=None,
+                        cfg=cfg,
+                        verbose=False
+                    )
+                    if is_main_process():
+                        eval_result = _result[0].results['bbox']['AP']
+            else:
+                results_dict = {}
+                cpu_device = torch.device("cpu")
+                for i, batch in enumerate(val_data_loader):
+                    images, targets, image_ids, positive_map, *_ = batch
+                    with torch.no_grad():
+                        images = images.to(device)
+                        if positive_map is None:
+                            output = model(images)
+                        else:
+                            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
+                            output = model(images, captions, positive_map)
+                        output = [o.to(cpu_device) for o in output]
+                    results_dict.update(
+                        {img_id: result for img_id, result in zip(image_ids, output)}
+                    )
+                all_predictions = all_gather(results_dict)
+                if is_main_process():
+                    predictions = {}
+                    for p in all_predictions:
+                        predictions.update(p)
+                    predictions = [predictions[i] for i in list(sorted(predictions.keys()))]
+                    eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None,
+                                            box_only=cfg.DATASETS.CLASS_AGNOSTIC)
+                    if cfg.DATASETS.CLASS_AGNOSTIC:
+                        eval_result = eval_result.results['box_proposal']['AR@100']
+                    else:
+                        eval_result = eval_result.results['bbox']['AP']
+            model.train()
+
+            if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR:
+                model_ema.ema.eval()
+                results_dict = {}
+                cpu_device = torch.device("cpu")
+                for i, batch in enumerate(val_data_loader):
+                    images, targets, image_ids, positive_map, positive_map_eval = batch
+                    with torch.no_grad():
+                        images = images.to(device)
+                        if positive_map is None:
+                            output = model_ema.ema(images)
+                        else:
+                            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
+                            output = model_ema.ema(images, captions, positive_map)
+                        output = [o.to(cpu_device) for o in output]
+                    results_dict.update(
+                        {img_id: result for img_id, result in zip(image_ids, output)}
+                    )
+                all_predictions = all_gather(results_dict)
+                if is_main_process():
+                    predictions = {}
+                    for p in all_predictions:
+                        predictions.update(p)
+                    predictions = [predictions[i] for i in list(sorted(predictions.keys()))]
+                    eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None,
+                                              box_only=cfg.DATASETS.CLASS_AGNOSTIC)
+                    if cfg.DATASETS.CLASS_AGNOSTIC:
+                        eval_result = eval_result.results['box_proposal']['AR@100']
+                    else:
+                        eval_result = eval_result.results['bbox']['AP']
+                
+            arguments.update(eval_result=eval_result)
+
+            if cfg.SOLVER.USE_AUTOSTEP:
+                eval_result = all_gather(eval_result)[0] #broadcast_data([eval_result])[0]
+                # print("Rank {} eval result gathered".format(cfg.local_rank), eval_result)
+                scheduler.step(eval_result)
+            
+            if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
+                if eval_result < previous_best:
+                    patience_counter += 1
+                else:
+                    patience_counter = 0
+                    previous_best = eval_result
+                    checkpointer.save("model_best", **arguments)
+                print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result)
+                if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE:
+                    if is_main_process():
+                        print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best))
+                    break
+
+        if iteration % checkpoint_period == 0:
+            checkpointer.save("model_{:07d}".format(iteration), **arguments)
+        if iteration == max_iter:
+            checkpointer.save("model_final", **arguments)
+            break
+
+    total_training_time = time.time() - start_training_time
+    total_time_str = str(datetime.timedelta(seconds=total_training_time))
+    logger.info(
+        "Total training time: {} ({:.4f} s / it)".format(
+            total_time_str, total_training_time / (max_iter)
+        )
+    )
diff --git a/maskrcnn_benchmark/layers/__init__.py b/maskrcnn_benchmark/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9d1db2e7b5328cd8231d8045e4bea0fc88dd934
--- /dev/null
+++ b/maskrcnn_benchmark/layers/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+from .batch_norm import FrozenBatchNorm2d, NaiveSyncBatchNorm2d
+from .misc import Conv2d, _NewEmptyTensorOp
+from .misc import ConvTranspose2d
+from .misc import DFConv2d
+from .misc import interpolate
+from .misc import Scale
+from .nms import nms
+from .nms import ml_nms
+from .nms import soft_nms
+from .roi_align import ROIAlign
+from .roi_align import roi_align
+from .roi_align import ROIAlignV2
+from .roi_pool import ROIPool
+from .roi_pool import roi_pool
+from .smooth_l1_loss import smooth_l1_loss
+from .sigmoid_focal_loss import SigmoidFocalLoss, TokenSigmoidFocalLoss
+from .iou_loss import IOULoss, IOUWHLoss
+from .deform_conv import DeformConv, ModulatedDeformConv
+from .dropblock import DropBlock2D, DropBlock3D
+from .evonorm import EvoNorm2d
+from .dyrelu import DYReLU, swish
+from .se import SELayer, SEBlock
+from .dyhead import DyHead
+from .set_loss import HungarianMatcher, SetCriterion
+
+__all__ = ["nms", "ml_nms", "soft_nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool",
+           "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate", "swish",
+           "FrozenBatchNorm2d", "NaiveSyncBatchNorm2d", "SigmoidFocalLoss", "TokenSigmoidFocalLoss", "IOULoss",
+           "IOUWHLoss", "Scale", "DeformConv", "ModulatedDeformConv", "DyHead",
+           "DropBlock2D", "DropBlock3D", "EvoNorm2d", "DYReLU", "SELayer", "SEBlock",
+           "HungarianMatcher", "SetCriterion", "ROIAlignV2", "_NewEmptyTensorOp"]
diff --git a/maskrcnn_benchmark/layers/batch_norm.py b/maskrcnn_benchmark/layers/batch_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2a83dadfa3aa52b3f854017a9fd71655c2a7c3
--- /dev/null
+++ b/maskrcnn_benchmark/layers/batch_norm.py
@@ -0,0 +1,117 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch import nn
+
+import torch.distributed as dist
+import maskrcnn_benchmark.utils.comm as comm
+from torch.autograd.function import Function
+
+class FrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters
+    are fixed
+    """
+
+    def __init__(self, n):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def forward(self, x):
+        scale = self.weight * self.running_var.rsqrt()
+        bias = self.bias - self.running_mean * scale
+        scale = scale.reshape(1, -1, 1, 1)
+        bias = bias.reshape(1, -1, 1, 1)
+        return x * scale + bias
+
+
+class AllReduce(Function):
+    @staticmethod
+    def forward(ctx, input):
+        input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())]
+        # Use allgather instead of allreduce since I don't trust in-place operations ..
+        dist.all_gather(input_list, input, async_op=False)
+        inputs = torch.stack(input_list, dim=0)
+        return torch.sum(inputs, dim=0)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dist.all_reduce(grad_output, async_op=False)
+        return grad_output
+
+
+class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
+    """
+    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
+    when the batch size on each worker is different.
+    (e.g., when scale augmentation is used, or when it is applied to mask head).
+
+    This is a slower but correct alternative to `nn.SyncBatchNorm`.
+
+    Note:
+        There isn't a single definition of Sync BatchNorm.
+
+        When ``stats_mode==""``, this module computes overall statistics by using
+        statistics of each worker with equal weight.  The result is true statistics
+        of all samples (as if they are all on one worker) only when all workers
+        have the same (N, H, W). This mode does not support inputs with zero batch size.
+
+        When ``stats_mode=="N"``, this module computes overall statistics by weighting
+        the statistics of each worker by their ``N``. The result is true statistics
+        of all samples (as if they are all on one worker) only when all workers
+        have the same (H, W). It is slower than ``stats_mode==""``.
+
+        Even though the result of this module may not be the true statistics of all samples,
+        it may still be reasonable because it might be preferrable to assign equal weights
+        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
+        on larger images. From preliminary experiments, little difference is found between such
+        a simplified implementation and an accurate computation of overall mean & variance.
+    """
+
+    def __init__(self, *args, stats_mode="", **kwargs):
+        super().__init__(*args, **kwargs)
+        assert stats_mode in ["", "N"]
+        self._stats_mode = stats_mode
+
+    def forward(self, input):
+        if comm.get_world_size() == 1 or not self.training:
+            return super().forward(input)
+
+        B, C = input.shape[0], input.shape[1]
+
+        mean = torch.mean(input, dim=[0, 2, 3])
+        meansqr = torch.mean(input * input, dim=[0, 2, 3])
+
+        if self._stats_mode == "":
+            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
+            vec = torch.cat([mean, meansqr], dim=0)
+            vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
+            mean, meansqr = torch.split(vec, C)
+            momentum = self.momentum
+        else:
+            if B == 0:
+                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
+                vec = vec + input.sum()  # make sure there is gradient w.r.t input
+            else:
+                vec = torch.cat(
+                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
+                )
+            vec = AllReduce.apply(vec * B)
+
+            total_batch = vec[-1].detach()
+            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
+            total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
+            mean, meansqr, _ = torch.split(vec / total_batch, C)
+
+        var = meansqr - mean * mean
+        invstd = torch.rsqrt(var + self.eps)
+        scale = self.weight * invstd
+        bias = self.bias - mean * scale
+        scale = scale.reshape(1, -1, 1, 1)
+        bias = bias.reshape(1, -1, 1, 1)
+
+        self.running_mean += momentum * (mean.detach() - self.running_mean)
+        self.running_var += momentum * (var.detach() - self.running_var)
+        return input * scale + bias
\ No newline at end of file
diff --git a/maskrcnn_benchmark/layers/deform_conv.py b/maskrcnn_benchmark/layers/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9d78dcc10db200b7d8dae1fb4de252ba0868628
--- /dev/null
+++ b/maskrcnn_benchmark/layers/deform_conv.py
@@ -0,0 +1,436 @@
+import torch
+import math
+from torch import nn
+from torch.nn import init
+from torch.nn.modules.utils import _pair
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
+
+from maskrcnn_benchmark import _C
+
+class DeformConvFunction(Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        offset,
+        weight,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        deformable_groups=1,
+        im2col_step=64
+    ):
+        if input is not None and input.dim() != 4:
+            raise ValueError(
+                "Expected 4D tensor as input, got {}D tensor instead.".format(
+                    input.dim()))
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deformable_groups = deformable_groups
+        ctx.im2col_step = im2col_step
+
+        ctx.save_for_backward(input, offset, weight)
+
+        output = input.new_empty(
+            DeformConvFunction._output_size(input, weight, ctx.padding,
+                                            ctx.dilation, ctx.stride))
+
+        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
+
+        if not input.is_cuda:
+            raise NotImplementedError
+        else:
+            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+            assert (input.shape[0] %
+                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
+            _C.deform_conv_forward(
+                input,
+                weight,
+                offset,
+                output,
+                ctx.bufs_[0],
+                ctx.bufs_[1],
+                weight.size(3),
+                weight.size(2),
+                ctx.stride[1],
+                ctx.stride[0],
+                ctx.padding[1],
+                ctx.padding[0],
+                ctx.dilation[1],
+                ctx.dilation[0],
+                ctx.groups,
+                ctx.deformable_groups,
+                cur_im2col_step
+            )
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, offset, weight = ctx.saved_tensors
+
+        grad_input = grad_offset = grad_weight = None
+
+        if not grad_output.is_cuda:
+            raise NotImplementedError
+        else:
+            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+            assert (input.shape[0] %
+                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+                grad_input = torch.zeros_like(input)
+                grad_offset = torch.zeros_like(offset)
+                _C.deform_conv_backward_input(
+                    input,
+                    offset,
+                    grad_output,
+                    grad_input,
+                    grad_offset,
+                    weight,
+                    ctx.bufs_[0],
+                    weight.size(3),
+                    weight.size(2),
+                    ctx.stride[1],
+                    ctx.stride[0],
+                    ctx.padding[1],
+                    ctx.padding[0],
+                    ctx.dilation[1],
+                    ctx.dilation[0],
+                    ctx.groups,
+                    ctx.deformable_groups,
+                    cur_im2col_step
+                )
+
+            if ctx.needs_input_grad[2]:
+                grad_weight = torch.zeros_like(weight)
+                _C.deform_conv_backward_parameters(
+                    input,
+                    offset,
+                    grad_output,
+                    grad_weight,
+                    ctx.bufs_[0],
+                    ctx.bufs_[1],
+                    weight.size(3),
+                    weight.size(2),
+                    ctx.stride[1],
+                    ctx.stride[0],
+                    ctx.padding[1],
+                    ctx.padding[0],
+                    ctx.dilation[1],
+                    ctx.dilation[0],
+                    ctx.groups,
+                    ctx.deformable_groups,
+                    1,
+                    cur_im2col_step
+                )
+
+        return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+    @staticmethod
+    def _output_size(input, weight, padding, dilation, stride):
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = padding[d]
+            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(map(lambda s: s > 0, output_size)):
+            raise ValueError(
+                "convolution input is too small (output would be {})".format(
+                    'x'.join(map(str, output_size))))
+        return output_size
+
+class ModulatedDeformConvFunction(Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        offset,
+        mask,
+        weight,
+        bias=None,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        deformable_groups=1
+    ):
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.groups = groups
+        ctx.deformable_groups = deformable_groups
+        ctx.with_bias = bias is not None
+        if not ctx.with_bias:
+            bias = input.new_empty(1)  # fake tensor
+        if not input.is_cuda:
+            raise NotImplementedError
+        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+                or input.requires_grad:
+            ctx.save_for_backward(input, offset, mask, weight, bias)
+        output = input.new_empty(
+            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+        _C.modulated_deform_conv_forward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            output,
+            ctx._bufs[1],
+            weight.shape[2],
+            weight.shape[3],
+            ctx.stride,
+            ctx.stride,
+            ctx.padding,
+            ctx.padding,
+            ctx.dilation,
+            ctx.dilation,
+            ctx.groups,
+            ctx.deformable_groups,
+            ctx.with_bias
+        )
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        if not grad_output.is_cuda:
+            raise NotImplementedError
+        input, offset, mask, weight, bias = ctx.saved_tensors
+        grad_input = torch.zeros_like(input)
+        grad_offset = torch.zeros_like(offset)
+        grad_mask = torch.zeros_like(mask)
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
+        _C.modulated_deform_conv_backward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            ctx._bufs[1],
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_offset,
+            grad_mask,
+            grad_output,
+            weight.shape[2],
+            weight.shape[3],
+            ctx.stride,
+            ctx.stride,
+            ctx.padding,
+            ctx.padding,
+            ctx.dilation,
+            ctx.dilation,
+            ctx.groups,
+            ctx.deformable_groups,
+            ctx.with_bias
+        )
+        if not ctx.with_bias:
+            grad_bias = None
+
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+                None, None, None, None, None)
+
+    @staticmethod
+    def _infer_shape(ctx, input, weight):
+        n = input.size(0)
+        channels_out = weight.size(0)
+        height, width = input.shape[2:4]
+        kernel_h, kernel_w = weight.shape[2:4]
+        height_out = (height + 2 * ctx.padding -
+                      (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+        width_out = (width + 2 * ctx.padding -
+                     (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+        return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        deformable_groups=1,
+        bias=False
+    ):
+        assert not bias
+        super(DeformConv, self).__init__()
+        self.with_bias = bias
+
+        assert in_channels % groups == 0, \
+            'in_channels {} cannot be divisible by groups {}'.format(
+                in_channels, groups)
+        assert out_channels % groups == 0, \
+            'out_channels {} cannot be divisible by groups {}'.format(
+                out_channels, groups)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // self.groups,
+                         *self.kernel_size))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(self, input, offset):
+        return deform_conv(input, offset, self.weight, self.stride,
+                           self.padding, self.dilation, self.groups,
+                           self.deformable_groups)
+
+    def __repr__(self):
+        return "".join([
+            "{}(".format(self.__class__.__name__),
+            "in_channels={}, ".format(self.in_channels),
+            "out_channels={}, ".format(self.out_channels),
+            "kernel_size={}, ".format(self.kernel_size),
+            "stride={}, ".format(self.stride),
+            "dilation={}, ".format(self.dilation),
+            "padding={}, ".format(self.padding),
+            "groups={}, ".format(self.groups),
+            "deformable_groups={}, ".format(self.deformable_groups),
+            "bias={})".format(self.with_bias),
+        ])
+
+class ModulatedDeformConv(nn.Module):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        deformable_groups=1,
+        bias=True
+    ):
+        super(ModulatedDeformConv, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+        self.with_bias = bias
+
+        self.weight = nn.Parameter(torch.Tensor(
+            out_channels,
+            in_channels // groups,
+            *self.kernel_size
+        ))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(self, input, offset, mask):
+        return modulated_deform_conv(
+            input, offset, mask, self.weight, self.bias, self.stride,
+            self.padding, self.dilation, self.groups, self.deformable_groups)
+
+    def __repr__(self):
+        return "".join([
+            "{}(".format(self.__class__.__name__),
+            "in_channels={}, ".format(self.in_channels),
+            "out_channels={}, ".format(self.out_channels),
+            "kernel_size={}, ".format(self.kernel_size),
+            "stride={}, ".format(self.stride),
+            "dilation={}, ".format(self.dilation),
+            "padding={}, ".format(self.padding),
+            "groups={}, ".format(self.groups),
+            "deformable_groups={}, ".format(self.deformable_groups),
+            "bias={})".format(self.with_bias),
+        ])
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=1,
+                 bias=True):
+        super(ModulatedDeformConvPack, self).__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, deformable_groups, bias)
+
+        self.conv_offset_mask = nn.Conv2d(
+            self.in_channels // self.groups,
+            self.deformable_groups * 3 * self.kernel_size[0] *
+            self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        self.conv_offset_mask.weight.data.zero_()
+        self.conv_offset_mask.bias.data.zero_()
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(self, input):
+        out = self.conv_offset_mask(input)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv(
+            input, offset, mask, self.weight, self.bias, self.stride,
+            self.padding, self.dilation, self.groups, self.deformable_groups)
diff --git a/maskrcnn_benchmark/layers/deform_pool.py b/maskrcnn_benchmark/layers/deform_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..1202e1d657c45cba7c8fa34a2684ae13e957ca30
--- /dev/null
+++ b/maskrcnn_benchmark/layers/deform_pool.py
@@ -0,0 +1,423 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .deform_conv import DeformConv2d
+
+def add_conv(in_ch, out_ch, ksize, stride, leaky=True):
+    """
+    Add a conv2d / batchnorm / leaky ReLU block.
+    Args:
+        in_ch (int): number of input channels of the convolution layer.
+        out_ch (int): number of output channels of the convolution layer.
+        ksize (int): kernel size of the convolution layer.
+        stride (int): stride of the convolution layer.
+    Returns:
+        stage (Sequential) : Sequential layers composing a convolution block.
+    """
+    stage = nn.Sequential()
+    pad = (ksize - 1) // 2
+    stage.add_module('conv', nn.Conv2d(in_channels=in_ch,
+                                       out_channels=out_ch, kernel_size=ksize, stride=stride,
+                                       padding=pad, bias=False))
+    stage.add_module('batch_norm', nn.BatchNorm2d(out_ch))
+    if leaky:
+        stage.add_module('leaky', nn.LeakyReLU(0.1))
+    else:
+        stage.add_module('relu6', nn.ReLU6(inplace=True))
+    return stage
+
+
+class upsample(nn.Module):
+    __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name']
+
+    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
+        super(upsample, self).__init__()
+        self.name = type(self).__name__
+        self.size = size
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, input):
+        return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
+
+    def extra_repr(self):
+        if self.scale_factor is not None:
+            info = 'scale_factor=' + str(self.scale_factor)
+        else:
+            info = 'size=' + str(self.size)
+        info += ', mode=' + self.mode
+        return info
+
+class SPPLayer(nn.Module):
+    def __init__(self):
+        super(SPPLayer, self).__init__()
+
+    def forward(self, x):
+        x_1 = x
+        x_2 = F.max_pool2d(x, 5, stride=1, padding=2)
+        x_3 = F.max_pool2d(x, 9, stride=1, padding=4)
+        x_4 = F.max_pool2d(x, 13, stride=1, padding=6)
+        out = torch.cat((x_1, x_2, x_3, x_4),dim=1)
+        return out
+
+class DropBlock(nn.Module):
+    def __init__(self, block_size=7, keep_prob=0.9):
+        super(DropBlock, self).__init__()
+        self.block_size = block_size
+        self.keep_prob = keep_prob
+        self.gamma = None
+        self.kernel_size = (block_size, block_size)
+        self.stride = (1, 1)
+        self.padding = (block_size//2, block_size//2)
+
+    def reset(self, block_size, keep_prob):
+        self.block_size = block_size
+        self.keep_prob = keep_prob
+        self.gamma = None
+        self.kernel_size = (block_size, block_size)
+        self.stride = (1, 1)
+        self.padding = (block_size//2, block_size//2)
+
+    def calculate_gamma(self, x):
+        return  (1-self.keep_prob) * x.shape[-1]**2/ \
+                (self.block_size**2 * (x.shape[-1] - self.block_size + 1)**2)
+
+    def forward(self, x):
+        if (not self.training or self.keep_prob==1): #set keep_prob=1 to turn off dropblock
+            return x
+        if self.gamma is None:
+            self.gamma = self.calculate_gamma(x)
+        if x.type() == 'torch.cuda.HalfTensor': #TODO: not fully support for FP16 now
+            FP16 = True
+            x = x.float()
+        else:
+            FP16 = False
+        p = torch.ones_like(x) * (self.gamma)
+        mask = 1 - torch.nn.functional.max_pool2d(torch.bernoulli(p),
+                                                  self.kernel_size,
+                                                  self.stride,
+                                                  self.padding)
+
+        out =  mask * x * (mask.numel()/mask.sum())
+
+        if FP16:
+            out = out.half()
+        return out
+
+class resblock(nn.Module):
+    """
+    Sequential residual blocks each of which consists of \
+    two convolution layers.
+    Args:
+        ch (int): number of input and output channels.
+        nblocks (int): number of residual blocks.
+        shortcut (bool): if True, residual tensor addition is enabled.
+    """
+    def __init__(self, ch, nblocks=1, shortcut=True):
+
+        super().__init__()
+        self.shortcut = shortcut
+        self.module_list = nn.ModuleList()
+        for i in range(nblocks):
+            resblock_one = nn.ModuleList()
+            resblock_one.append(add_conv(ch, ch//2, 1, 1))
+            resblock_one.append(add_conv(ch//2, ch, 3, 1))
+            self.module_list.append(resblock_one)
+
+    def forward(self, x):
+        for module in self.module_list:
+            h = x
+            for res in module:
+                h = res(h)
+            x = x + h if self.shortcut else h
+        return x
+
+
+class RFBblock(nn.Module):
+    def __init__(self,in_ch,residual=False):
+        super(RFBblock, self).__init__()
+        inter_c = in_ch // 4
+        self.branch_0 = nn.Sequential(
+            nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0),
+        )
+        self.branch_1 = nn.Sequential(
+            nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0),
+            nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, padding=1)
+        )
+        self.branch_2 = nn.Sequential(
+            nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0),
+            nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, padding=1),
+            nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, dilation=2, padding=2)
+        )
+        self.branch_3 = nn.Sequential(
+            nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0),
+            nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=5, stride=1, padding=2),
+            nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, dilation=3, padding=3)
+        )
+        self.residual= residual
+
+    def forward(self,x):
+        x_0 = self.branch_0(x)
+        x_1 = self.branch_1(x)
+        x_2 = self.branch_2(x)
+        x_3 = self.branch_3(x)
+        out = torch.cat((x_0,x_1,x_2,x_3),1)
+        if self.residual:
+            out +=x
+        return out
+
+
+class FeatureAdaption(nn.Module):
+    def __init__(self, in_ch, out_ch, n_anchors, rfb=False, sep=False):
+        super(FeatureAdaption, self).__init__()
+        if sep:
+            self.sep=True
+        else:
+            self.sep=False
+            self.conv_offset = nn.Conv2d(in_channels=2*n_anchors,
+                                         out_channels=2*9*n_anchors, groups = n_anchors, kernel_size=1,stride=1,padding=0)
+            self.dconv = DeformConv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1,
+                                      padding=1, deformable_groups=n_anchors)
+            self.rfb=None
+            if rfb:
+                self.rfb = RFBblock(out_ch)
+
+    def forward(self, input, wh_pred):
+        #The RFB block is added behind FeatureAdaption
+        #For mobilenet, we currently don't support rfb and FeatureAdaption
+        if self.sep:
+            return input
+        if self.rfb is not None:
+            input = self.rfb(input)
+        wh_pred_new = wh_pred.detach()
+        offset = self.conv_offset(wh_pred_new)
+        out = self.dconv(input, offset)
+        return out
+
+
+class ASFFmobile(nn.Module):
+    def __init__(self, level, rfb=False, vis=False):
+        super(ASFFmobile, self).__init__()
+        self.level = level
+        self.dim = [512, 256, 128]
+        self.inter_dim = self.dim[self.level]
+        if level==0:
+            self.stride_level_1 = add_conv(256, self.inter_dim, 3, 2, leaky=False)
+            self.stride_level_2 = add_conv(128, self.inter_dim, 3, 2, leaky=False)
+            self.expand = add_conv(self.inter_dim, 1024, 3, 1, leaky=False)
+        elif level==1:
+            self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1, leaky=False)
+            self.stride_level_2 = add_conv(128, self.inter_dim, 3, 2, leaky=False)
+            self.expand = add_conv(self.inter_dim, 512, 3, 1, leaky=False)
+        elif level==2:
+            self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1, leaky=False)
+            self.compress_level_1 = add_conv(256, self.inter_dim, 1, 1, leaky=False)
+            self.expand = add_conv(self.inter_dim, 256, 3, 1,leaky=False)
+
+        compress_c = 8 if rfb else 16  #when adding rfb, we use half number of channels to save memory
+
+        self.weight_level_0 = add_conv(self.inter_dim, compress_c, 1, 1, leaky=False)
+        self.weight_level_1 = add_conv(self.inter_dim, compress_c, 1, 1, leaky=False)
+        self.weight_level_2 = add_conv(self.inter_dim, compress_c, 1, 1, leaky=False)
+
+        self.weight_levels = nn.Conv2d(compress_c*3, 3, kernel_size=1, stride=1, padding=0)
+        self.vis= vis
+
+
+    def forward(self, x_level_0, x_level_1, x_level_2):
+        if self.level==0:
+            level_0_resized = x_level_0
+            level_1_resized = self.stride_level_1(x_level_1)
+
+            level_2_downsampled_inter =F.max_pool2d(x_level_2, 3, stride=2, padding=1)
+            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
+
+        elif self.level==1:
+            level_0_compressed = self.compress_level_0(x_level_0)
+            level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')
+            level_1_resized =x_level_1
+            level_2_resized =self.stride_level_2(x_level_2)
+        elif self.level==2:
+            level_0_compressed = self.compress_level_0(x_level_0)
+            level_0_resized =F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')
+            level_1_compressed = self.compress_level_1(x_level_1)
+            level_1_resized =F.interpolate(level_1_compressed, scale_factor=2, mode='nearest')
+            level_2_resized =x_level_2
+
+        level_0_weight_v = self.weight_level_0(level_0_resized)
+        level_1_weight_v = self.weight_level_1(level_1_resized)
+        level_2_weight_v = self.weight_level_2(level_2_resized)
+        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v),1)
+        levels_weight = self.weight_levels(levels_weight_v)
+        levels_weight = F.softmax(levels_weight, dim=1)
+
+        fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+ \
+                            level_1_resized * levels_weight[:,1:2,:,:]+ \
+                            level_2_resized * levels_weight[:,2:,:,:]
+
+        out = self.expand(fused_out_reduced)
+
+        if self.vis:
+            return out, levels_weight, fused_out_reduced.sum(dim=1)
+        else:
+            return out
+
+
+class ASFF(nn.Module):
+    def __init__(self, level, rfb=False, vis=False):
+        super(ASFF, self).__init__()
+        self.level = level
+        self.dim = [512, 256, 256]
+        self.inter_dim = self.dim[self.level]
+        if level==0:
+            self.stride_level_1 = add_conv(256, self.inter_dim, 3, 2)
+            self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2)
+            self.expand = add_conv(self.inter_dim, 1024, 3, 1)
+        elif level==1:
+            self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1)
+            self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2)
+            self.expand = add_conv(self.inter_dim, 512, 3, 1)
+        elif level==2:
+            self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1)
+            self.expand = add_conv(self.inter_dim, 256, 3, 1)
+
+        compress_c = 8 if rfb else 16  #when adding rfb, we use half number of channels to save memory
+
+        self.weight_level_0 = add_conv(self.inter_dim, compress_c, 1, 1)
+        self.weight_level_1 = add_conv(self.inter_dim, compress_c, 1, 1)
+        self.weight_level_2 = add_conv(self.inter_dim, compress_c, 1, 1)
+
+        self.weight_levels = nn.Conv2d(compress_c*3, 3, kernel_size=1, stride=1, padding=0)
+        self.vis= vis
+
+
+    def forward(self, x_level_0, x_level_1, x_level_2):
+        if self.level==0:
+            level_0_resized = x_level_0
+            level_1_resized = self.stride_level_1(x_level_1)
+
+            level_2_downsampled_inter =F.max_pool2d(x_level_2, 3, stride=2, padding=1)
+            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
+
+        elif self.level==1:
+            level_0_compressed = self.compress_level_0(x_level_0)
+            level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')
+            level_1_resized =x_level_1
+            level_2_resized =self.stride_level_2(x_level_2)
+        elif self.level==2:
+            level_0_compressed = self.compress_level_0(x_level_0)
+            level_0_resized =F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')
+            level_1_resized =F.interpolate(x_level_1, scale_factor=2, mode='nearest')
+            level_2_resized =x_level_2
+
+        level_0_weight_v = self.weight_level_0(level_0_resized)
+        level_1_weight_v = self.weight_level_1(level_1_resized)
+        level_2_weight_v = self.weight_level_2(level_2_resized)
+        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v),1)
+        levels_weight = self.weight_levels(levels_weight_v)
+        levels_weight = F.softmax(levels_weight, dim=1)
+
+        fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+ \
+                            level_1_resized * levels_weight[:,1:2,:,:]+ \
+                            level_2_resized * levels_weight[:,2:,:,:]
+
+        out = self.expand(fused_out_reduced)
+
+        if self.vis:
+            return out, levels_weight, fused_out_reduced.sum(dim=1)
+        else:
+            return out
+
+def make_divisible(v, divisor, min_value=None):
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v:
+    :param divisor:
+    :param min_value:
+    :return:
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        padding = (kernel_size - 1) // 2
+        super(ConvBNReLU, self).__init__(
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.ReLU6(inplace=True)
+        )
+
+def add_sepconv(in_ch, out_ch, ksize, stride):
+
+    stage = nn.Sequential()
+    pad = (ksize - 1) // 2
+    stage.add_module('sepconv', nn.Conv2d(in_channels=in_ch,
+                                          out_channels=in_ch, kernel_size=ksize, stride=stride,
+                                          padding=pad, groups=in_ch, bias=False))
+    stage.add_module('sepbn', nn.BatchNorm2d(in_ch))
+    stage.add_module('seprelu6', nn.ReLU6(inplace=True))
+    stage.add_module('ptconv', nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False))
+    stage.add_module('ptbn', nn.BatchNorm2d(out_ch))
+    stage.add_module('ptrelu6', nn.ReLU6(inplace=True))
+    return stage
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        layers = []
+        if expand_ratio != 1:
+            # pw
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            # dw
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            # pw-linear
+            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(oup),
+        ])
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+class ressepblock(nn.Module):
+    def __init__(self, ch, out_ch, in_ch=None, shortcut=True):
+
+        super().__init__()
+        self.shortcut = shortcut
+        self.module_list = nn.ModuleList()
+        in_ch = ch//2 if in_ch==None else in_ch
+        resblock_one = nn.ModuleList()
+        resblock_one.append(add_conv(ch, in_ch, 1, 1, leaky=False))
+        resblock_one.append(add_conv(in_ch, out_ch, 3, 1,leaky=False))
+        self.module_list.append(resblock_one)
+
+    def forward(self, x):
+        for module in self.module_list:
+            h = x
+            for res in module:
+                h = res(h)
+            x = x + h if self.shortcut else h
+        return x
+
diff --git a/maskrcnn_benchmark/layers/dropblock.py b/maskrcnn_benchmark/layers/dropblock.py
new file mode 100644
index 0000000000000000000000000000000000000000..3210b99ec5d82d65e448363315df28c4c5f2d239
--- /dev/null
+++ b/maskrcnn_benchmark/layers/dropblock.py
@@ -0,0 +1,146 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class DropBlock2D(nn.Module):
+    r"""Randomly zeroes 2D spatial blocks of the input tensor.
+
+    As described in the paper
+    `DropBlock: A regularization method for convolutional networks`_ ,
+    dropping whole blocks of feature map allows to remove semantic
+    information as compared to regular dropout.
+
+    Args:
+        drop_prob (float): probability of an element to be dropped.
+        block_size (int): size of the block to drop
+
+    Shape:
+        - Input: `(N, C, H, W)`
+        - Output: `(N, C, H, W)`
+
+    .. _DropBlock: A regularization method for convolutional networks:
+       https://arxiv.org/abs/1810.12890
+
+    """
+
+    def __init__(self, drop_prob, block_size):
+        super(DropBlock2D, self).__init__()
+
+        self.drop_prob = drop_prob
+        self.block_size = block_size
+
+    def forward(self, x):
+        # shape: (bsize, channels, height, width)
+
+        assert x.dim() == 4, \
+            "Expected input with 4 dimensions (bsize, channels, height, width)"
+
+        if not self.training or self.drop_prob == 0.:
+            return x
+        else:
+            # get gamma value
+            gamma = self._compute_gamma(x)
+
+            # sample mask
+            mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()
+
+            # place mask on input device
+            mask = mask.to(x.device)
+
+            # compute block mask
+            block_mask = self._compute_block_mask(mask)
+
+            # apply block mask
+            out = x * block_mask[:, None, :, :]
+
+            # scale output
+            out = out * block_mask.numel() / block_mask.sum()
+
+            return out
+
+    def _compute_block_mask(self, mask):
+        block_mask = F.max_pool2d(input=mask[:, None, :, :],
+                                  kernel_size=(self.block_size, self.block_size),
+                                  stride=(1, 1),
+                                  padding=self.block_size // 2)
+
+        if self.block_size % 2 == 0:
+            block_mask = block_mask[:, :, :-1, :-1]
+
+        block_mask = 1 - block_mask.squeeze(1)
+
+        return block_mask
+
+    def _compute_gamma(self, x):
+        return self.drop_prob / (self.block_size ** 2)
+
+
+class DropBlock3D(DropBlock2D):
+    r"""Randomly zeroes 3D spatial blocks of the input tensor.
+
+    An extension to the concept described in the paper
+    `DropBlock: A regularization method for convolutional networks`_ ,
+    dropping whole blocks of feature map allows to remove semantic
+    information as compared to regular dropout.
+
+    Args:
+        drop_prob (float): probability of an element to be dropped.
+        block_size (int): size of the block to drop
+
+    Shape:
+        - Input: `(N, C, D, H, W)`
+        - Output: `(N, C, D, H, W)`
+
+    .. _DropBlock: A regularization method for convolutional networks:
+       https://arxiv.org/abs/1810.12890
+
+    """
+
+    def __init__(self, drop_prob, block_size):
+        super(DropBlock3D, self).__init__(drop_prob, block_size)
+
+    def forward(self, x):
+        # shape: (bsize, channels, depth, height, width)
+
+        assert x.dim() == 5, \
+            "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
+
+        if not self.training or self.drop_prob == 0.:
+            return x
+        else:
+            # get gamma value
+            gamma = self._compute_gamma(x)
+
+            # sample mask
+            mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()
+
+            # place mask on input device
+            mask = mask.to(x.device)
+
+            # compute block mask
+            block_mask = self._compute_block_mask(mask)
+
+            # apply block mask
+            out = x * block_mask[:, None, :, :, :]
+
+            # scale output
+            out = out * block_mask.numel() / block_mask.sum()
+
+            return out
+
+    def _compute_block_mask(self, mask):
+        block_mask = F.max_pool3d(input=mask[:, None, :, :, :],
+                                  kernel_size=(self.block_size, self.block_size, self.block_size),
+                                  stride=(1, 1, 1),
+                                  padding=self.block_size // 2)
+
+        if self.block_size % 2 == 0:
+            block_mask = block_mask[:, :, :-1, :-1, :-1]
+
+        block_mask = 1 - block_mask.squeeze(1)
+
+        return block_mask
+
+    def _compute_gamma(self, x):
+        return self.drop_prob / (self.block_size ** 3)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/layers/dyhead.py b/maskrcnn_benchmark/layers/dyhead.py
new file mode 100644
index 0000000000000000000000000000000000000000..91fa88cb0beaef03e6459d671de843496ebe27f4
--- /dev/null
+++ b/maskrcnn_benchmark/layers/dyhead.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .deform_conv import ModulatedDeformConv
+from .dyrelu import h_sigmoid, DYReLU
+
+
+class Conv3x3Norm(torch.nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 deformable=False,
+                 use_gn=False):
+        super(Conv3x3Norm, self).__init__()
+
+        if deformable:
+            self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
+        else:
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
+
+        if use_gn:
+            self.bn = nn.GroupNorm(num_groups=16, num_channels=out_channels)
+        else:
+            self.bn = None
+
+    def forward(self, input, **kwargs):
+        x = self.conv(input, **kwargs)
+        if self.bn:
+            x = self.bn(x)
+        return x
+
+
+class DyConv(nn.Module):
+    def __init__(self,
+                 in_channels=256,
+                 out_channels=256,
+                 conv_func=Conv3x3Norm,
+                 use_dyfuse=True,
+                 use_dyrelu=False,
+                 use_deform=False
+                 ):
+        super(DyConv, self).__init__()
+
+        self.DyConv = nn.ModuleList()
+        self.DyConv.append(conv_func(in_channels, out_channels, 1))
+        self.DyConv.append(conv_func(in_channels, out_channels, 1))
+        self.DyConv.append(conv_func(in_channels, out_channels, 2))
+
+        if use_dyfuse:
+            self.AttnConv = nn.Sequential(
+                nn.AdaptiveAvgPool2d(1),
+                nn.Conv2d(in_channels, 1, kernel_size=1),
+                nn.ReLU(inplace=True))
+            self.h_sigmoid = h_sigmoid()
+        else:
+            self.AttnConv = None
+
+        if use_dyrelu:
+            self.relu = DYReLU(in_channels, out_channels)
+        else:
+            self.relu = nn.ReLU()
+
+        if use_deform:
+            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
+        else:
+            self.offset = None
+
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.DyConv.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight.data, 0, 0.01)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+        if self.AttnConv is not None:
+            for m in self.AttnConv.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.normal_(m.weight.data, 0, 0.01)
+                    if m.bias is not None:
+                        m.bias.data.zero_()
+
+    def forward(self, x):
+        next_x = []
+        for level, feature in enumerate(x):
+
+            conv_args = dict()
+            if self.offset is not None:
+                offset_mask = self.offset(feature)
+                offset = offset_mask[:, :18, :, :]
+                mask = offset_mask[:, 18:, :, :].sigmoid()
+                conv_args = dict(offset=offset, mask=mask)
+
+            temp_fea = [self.DyConv[1](feature, **conv_args)]
+
+            if level > 0:
+                temp_fea.append(self.DyConv[2](x[level - 1], **conv_args))
+            if level < len(x) - 1:
+                temp_fea.append(F.upsample_bilinear(self.DyConv[0](x[level + 1], **conv_args),
+                                                    size=[feature.size(2), feature.size(3)]))
+            mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False)
+
+            if self.AttnConv is not None:
+                attn_fea = []
+                res_fea = []
+                for fea in temp_fea:
+                    res_fea.append(fea)
+                    attn_fea.append(self.AttnConv(fea))
+
+                res_fea = torch.stack(res_fea)
+                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea))
+
+                mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False)
+
+            next_x.append(self.relu(mean_fea))
+
+        return next_x
+
+
+class DyHead(nn.Module):
+    def __init__(self, cfg, in_channels):
+        super(DyHead, self).__init__()
+        self.cfg = cfg
+        channels    = cfg.MODEL.DYHEAD.CHANNELS
+        use_gn      = cfg.MODEL.DYHEAD.USE_GN
+        use_dyrelu  = cfg.MODEL.DYHEAD.USE_DYRELU
+        use_dyfuse  = cfg.MODEL.DYHEAD.USE_DYFUSE
+        use_deform  = cfg.MODEL.DYHEAD.USE_DFCONV
+
+        conv_func = lambda i,o,s : Conv3x3Norm(i,o,s,deformable=use_deform,use_gn=use_gn)
+
+        dyhead_tower = []
+        for i in range(cfg.MODEL.DYHEAD.NUM_CONVS):
+            dyhead_tower.append(
+                DyConv(
+                    in_channels if i == 0 else channels,
+                    channels,
+                    conv_func=conv_func,
+                    use_dyrelu=use_dyrelu,
+                    use_dyfuse=use_dyfuse,
+                    use_deform=use_deform
+                )
+            )
+
+        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))
+
+    def forward(self, x):
+        dyhead_tower = self.dyhead_tower(x)
+        return dyhead_tower
\ No newline at end of file
diff --git a/maskrcnn_benchmark/layers/dyrelu.py b/maskrcnn_benchmark/layers/dyrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..070b2e99df0f473faec0ef5914e1d385fda8e4f4
--- /dev/null
+++ b/maskrcnn_benchmark/layers/dyrelu.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def _make_divisible(v, divisor, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class h_swish(nn.Module):
+    def __init__(self, inplace=False):
+        super(h_swish, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
+
+
+class h_sigmoid(nn.Module):
+    def __init__(self, inplace=True, h_max=1):
+        super(h_sigmoid, self).__init__()
+        self.relu = nn.ReLU6(inplace=inplace)
+        self.h_max = h_max
+
+    def forward(self, x):
+        return self.relu(x + 3) * self.h_max / 6
+
+
+class DYReLU(nn.Module):
+    def __init__(self, inp, oup, reduction=4, lambda_a=1.0, K2=True, use_bias=True, use_spatial=False,
+                 init_a=[1.0, 0.0], init_b=[0.0, 0.0]):
+        super(DYReLU, self).__init__()
+        self.oup = oup
+        self.lambda_a = lambda_a * 2
+        self.K2 = K2
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+        self.use_bias = use_bias
+        if K2:
+            self.exp = 4 if use_bias else 2
+        else:
+            self.exp = 2 if use_bias else 1
+        self.init_a = init_a
+        self.init_b = init_b
+
+        # determine squeeze
+        if reduction == 4:
+            squeeze = inp // reduction
+        else:
+            squeeze = _make_divisible(inp // reduction, 4)
+        # print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
+        # print('init_a: {}, init_b: {}'.format(self.init_a, self.init_b))
+
+        self.fc = nn.Sequential(
+            nn.Linear(inp, squeeze),
+            nn.ReLU(inplace=True),
+            nn.Linear(squeeze, oup * self.exp),
+            h_sigmoid()
+        )
+        if use_spatial:
+            self.spa = nn.Sequential(
+                nn.Conv2d(inp, 1, kernel_size=1),
+                nn.BatchNorm2d(1),
+            )
+        else:
+            self.spa = None
+
+    def forward(self, x):
+        if isinstance(x, list):
+            x_in = x[0]
+            x_out = x[1]
+        else:
+            x_in = x
+            x_out = x
+        b, c, h, w = x_in.size()
+        y = self.avg_pool(x_in).view(b, c)
+        y = self.fc(y).view(b, self.oup * self.exp, 1, 1)
+        if self.exp == 4:
+            a1, b1, a2, b2 = torch.split(y, self.oup, dim=1)
+            a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0]  # 1.0
+            a2 = (a2 - 0.5) * self.lambda_a + self.init_a[1]
+
+            b1 = b1 - 0.5 + self.init_b[0]
+            b2 = b2 - 0.5 + self.init_b[1]
+            out = torch.max(x_out * a1 + b1, x_out * a2 + b2)
+        elif self.exp == 2:
+            if self.use_bias:  # bias but not PL
+                a1, b1 = torch.split(y, self.oup, dim=1)
+                a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0]  # 1.0
+                b1 = b1 - 0.5 + self.init_b[0]
+                out = x_out * a1 + b1
+
+            else:
+                a1, a2 = torch.split(y, self.oup, dim=1)
+                a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0]  # 1.0
+                a2 = (a2 - 0.5) * self.lambda_a + self.init_a[1]
+                out = torch.max(x_out * a1, x_out * a2)
+
+        elif self.exp == 1:
+            a1 = y
+            a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0]  # 1.0
+            out = x_out * a1
+
+        if self.spa:
+            ys = self.spa(x_in).view(b, -1)
+            ys = F.softmax(ys, dim=1).view(b, 1, h, w) * h * w
+            ys = F.hardtanh(ys, 0, 3, inplace=True)/3
+            out = out * ys
+
+        return out
diff --git a/maskrcnn_benchmark/layers/evonorm.py b/maskrcnn_benchmark/layers/evonorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..058c0990427d0070fbb218e54119109536e80b66
--- /dev/null
+++ b/maskrcnn_benchmark/layers/evonorm.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+
+
+class EvoNorm2d(nn.Module):
+    __constants__ = ['num_features', 'eps', 'nonlinearity']
+
+    def __init__(self, num_features, eps=1e-5, nonlinearity=True, group=32):
+        super(EvoNorm2d, self).__init__()
+
+        self.num_features = num_features
+        self.eps = eps
+        self.nonlinearity = nonlinearity
+        self.group = group
+
+        self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+        self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+        if self.nonlinearity:
+            self.v = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+        if self.nonlinearity:
+            nn.init.ones_(self.v)
+
+    def group_std(self, x, groups=32):
+        N, C, H, W = x.shape
+        x = torch.reshape(x, (N, groups, C // groups, H, W))
+        std = torch.std(x, (3, 4), keepdim=True)
+        return torch.reshape(std + self.eps, (N, C, 1, 1))
+
+    def forward(self, x):
+        if self.nonlinearity:
+            num = x * torch.sigmoid(self.v * x)
+            return num / self.group_std(x, self.group) * self.weight + self.bias
+        else:
+            return x * self.weight + self.bias
\ No newline at end of file
diff --git a/maskrcnn_benchmark/layers/iou_loss.py b/maskrcnn_benchmark/layers/iou_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..80703b20bfd0d443e66ee089f2d3653330238dbe
--- /dev/null
+++ b/maskrcnn_benchmark/layers/iou_loss.py
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+
+
+class IOULoss(nn.Module):
+    def __init__(self, loss_type="iou"):
+        super(IOULoss, self).__init__()
+        self.loss_type = loss_type
+
+    def forward(self, pred, target, weight=None):
+        pred_left = pred[:, 0]
+        pred_top = pred[:, 1]
+        pred_right = pred[:, 2]
+        pred_bottom = pred[:, 3]
+
+        target_left = target[:, 0]
+        target_top = target[:, 1]
+        target_right = target[:, 2]
+        target_bottom = target[:, 3]
+
+        target_area = (target_left + target_right) * \
+                      (target_top + target_bottom)
+        pred_area = (pred_left + pred_right) * \
+                    (pred_top + pred_bottom)
+
+        w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right)
+        g_w_intersect = torch.max(pred_left, target_left) + torch.max(
+            pred_right, target_right)
+        h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top)
+        g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top)
+        ac_uion = g_w_intersect * g_h_intersect + 1e-7
+        area_intersect = w_intersect * h_intersect
+        area_union = target_area + pred_area - area_intersect
+        ious = (area_intersect + 1.0) / (area_union + 1.0)
+        gious = ious - (ac_uion - area_union) / ac_uion
+        if self.loss_type == 'iou':
+            losses = -torch.log(ious)
+        elif self.loss_type == 'linear_iou':
+            losses = 1 - ious
+        elif self.loss_type == 'giou':
+            losses = 1 - gious
+        else:
+            raise NotImplementedError
+
+        if weight is not None and weight.sum() > 0:
+            return (losses * weight).sum()
+        else:
+            assert losses.numel() != 0
+            return losses.sum()
+
+
+class IOUWHLoss(nn.Module):  # used for anchor guiding
+    def __init__(self, reduction='none'):
+        super(IOUWHLoss, self).__init__()
+        self.reduction = reduction
+
+    def forward(self, pred, target):
+        orig_shape = pred.shape
+        pred = pred.view(-1, 4)
+        target = target.view(-1, 4)
+        target[:, :2] = 0
+        tl = torch.max((target[:, :2] - pred[:, 2:] / 2),
+                       (target[:, :2] - target[:, 2:] / 2))
+
+        br = torch.min((target[:, :2] + pred[:, 2:] / 2),
+                       (target[:, :2] + target[:, 2:] / 2))
+
+        area_p = torch.prod(pred[:, 2:], 1)
+        area_g = torch.prod(target[:, 2:], 1)
+
+        en = (tl < br).type(tl.type()).prod(dim=1)
+        area_i = torch.prod(br - tl, 1) * en
+        U = area_p + area_g - area_i + 1e-16
+        iou = area_i / U
+
+        loss = 1 - iou ** 2
+        if self.reduction == 'mean':
+            loss = loss.mean()
+        elif self.reduction == 'sum':
+            loss = loss.sum()
+
+        return loss
diff --git a/maskrcnn_benchmark/layers/misc.py b/maskrcnn_benchmark/layers/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe175249c6494e19e3d077478025ef9e8335306d
--- /dev/null
+++ b/maskrcnn_benchmark/layers/misc.py
@@ -0,0 +1,205 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+helper class that supports empty tensors on some nn functions.
+
+Ideally, add support directly in PyTorch to empty tensors in
+those functions.
+
+This can be removed once https://github.com/pytorch/pytorch/issues/12013
+is implemented
+"""
+
+import math
+import torch
+from torch.nn.modules.utils import _ntuple
+
+
+class _NewEmptyTensorOp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, new_shape):
+        ctx.shape = x.shape
+        return x.new_empty(new_shape)
+
+    @staticmethod
+    def backward(ctx, grad):
+        shape = ctx.shape
+        return _NewEmptyTensorOp.apply(grad, shape), None
+
+
+class Conv2d(torch.nn.Conv2d):
+    def forward(self, x):
+        if x.numel() > 0:
+            return super(Conv2d, self).forward(x)
+        # get output shape
+
+        output_shape = [
+            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
+            for i, p, di, k, d in zip(
+                x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
+            )
+        ]
+        output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
+        return _NewEmptyTensorOp.apply(x, output_shape)
+
+
+class ConvTranspose2d(torch.nn.ConvTranspose2d):
+    def forward(self, x):
+        if x.numel() > 0:
+            return super(ConvTranspose2d, self).forward(x)
+        # get output shape
+
+        output_shape = [
+            (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op
+            for i, p, di, k, d, op in zip(
+                x.shape[-2:],
+                self.padding,
+                self.dilation,
+                self.kernel_size,
+                self.stride,
+                self.output_padding,
+            )
+        ]
+        output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
+        return _NewEmptyTensorOp.apply(x, output_shape)
+
+
+class BatchNorm2d(torch.nn.BatchNorm2d):
+    def forward(self, x):
+        if x.numel() > 0:
+            return super(BatchNorm2d, self).forward(x)
+        # get output shape
+        output_shape = x.shape
+        return _NewEmptyTensorOp.apply(x, output_shape)
+
+
+def interpolate(
+    input, size=None, scale_factor=None, mode="nearest", align_corners=None
+):
+    if input.numel() > 0:
+        return torch.nn.functional.interpolate(
+            input, size, scale_factor, mode, align_corners
+        )
+
+    def _check_size_scale_factor(dim):
+        if size is None and scale_factor is None:
+            raise ValueError("either size or scale_factor should be defined")
+        if size is not None and scale_factor is not None:
+            raise ValueError("only one of size or scale_factor should be defined")
+        if (
+            scale_factor is not None
+            and isinstance(scale_factor, tuple)
+            and len(scale_factor) != dim
+        ):
+            raise ValueError(
+                "scale_factor shape must match input shape. "
+                "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
+            )
+
+    def _output_size(dim):
+        _check_size_scale_factor(dim)
+        if size is not None:
+            return size
+        scale_factors = _ntuple(dim)(scale_factor)
+        # math.floor might return float in py2.7
+        return [
+            int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
+        ]
+
+    output_shape = tuple(_output_size(2))
+    output_shape = input.shape[:-2] + output_shape
+    return _NewEmptyTensorOp.apply(input, output_shape)
+
+
+class Scale(torch.nn.Module):
+    def __init__(self, init_value=1.0):
+        super(Scale, self).__init__()
+        self.scale = torch.nn.Parameter(torch.FloatTensor([init_value]))
+
+    def forward(self, input):
+        return input * self.scale
+
+
+class DFConv2d(torch.nn.Module):
+    """Deformable convolutional layer"""
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        with_modulated_dcn=True,
+        kernel_size=3,
+        stride=1,
+        groups=1,
+        padding=1,
+        dilation=1,
+        deformable_groups=1,
+        bias=False
+    ):
+        super(DFConv2d, self).__init__()
+        if isinstance(kernel_size, (list, tuple)):
+            assert len(kernel_size) == 2
+            offset_base_channels = kernel_size[0] * kernel_size[1]
+        else:
+            offset_base_channels = kernel_size * kernel_size
+        if with_modulated_dcn:
+            from maskrcnn_benchmark.layers import ModulatedDeformConv
+            offset_channels = offset_base_channels * 3 #default: 27
+            conv_block = ModulatedDeformConv
+        else:
+            from maskrcnn_benchmark.layers import DeformConv
+            offset_channels = offset_base_channels * 2 #default: 18
+            conv_block = DeformConv
+        self.offset = Conv2d(
+            in_channels,
+            deformable_groups * offset_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=1,
+            dilation=dilation
+        )
+        for l in [self.offset, ]:
+            torch.nn.init.kaiming_uniform_(l.weight, a=1)
+            torch.nn.init.constant_(l.bias, 0.)
+        self.conv = conv_block(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            deformable_groups=deformable_groups,
+            bias=bias
+        )
+        self.with_modulated_dcn = with_modulated_dcn
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.offset_base_channels = offset_base_channels
+
+    def forward(self, x):
+        if x.numel() > 0:
+            if not self.with_modulated_dcn:
+                offset = self.offset(x)
+                x = self.conv(x, offset)
+            else:
+                offset_mask = self.offset(x)
+                split_point = self.offset_base_channels * 2
+                offset = offset_mask[:, :split_point, :, :]
+                mask = offset_mask[:, split_point:, :, :].sigmoid()
+                x = self.conv(x, offset, mask)
+            return x
+        # get output shape
+        output_shape = [
+            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
+            for i, p, di, k, d in zip(
+                x.shape[-2:],
+                self.padding,
+                self.dilation,
+                self.kernel_size,
+                self.stride
+            )
+        ]
+        output_shape = [x.shape[0], self.conv.weight.shape[0]] + output_shape
+        return _NewEmptyTensorOp.apply(x, output_shape)
diff --git a/maskrcnn_benchmark/layers/nms.py b/maskrcnn_benchmark/layers/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..12e81ad4a2183b5fca497d33f0d34d5fcc0d4ea1
--- /dev/null
+++ b/maskrcnn_benchmark/layers/nms.py
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from maskrcnn_benchmark import _C
+
+try:
+    import torchvision
+    from torchvision.ops import nms
+except:
+    nms = _C.nms
+
+ml_nms = _C.ml_nms
+soft_nms = _C.soft_nms
+
+# nms.__doc__ = """
+# This function performs Non-maximum suppresion"""
diff --git a/maskrcnn_benchmark/layers/roi_align.py b/maskrcnn_benchmark/layers/roi_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..247397098aaa7bf71bcb652af5a6664f86265ce9
--- /dev/null
+++ b/maskrcnn_benchmark/layers/roi_align.py
@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from maskrcnn_benchmark import _C
+
+class _ROIAlign(Function):
+    @staticmethod
+    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
+        ctx.save_for_backward(roi)
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sampling_ratio = sampling_ratio
+        ctx.input_shape = input.size()
+        output = _C.roi_align_forward(
+            input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio
+        )
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        rois, = ctx.saved_tensors
+        output_size = ctx.output_size
+        spatial_scale = ctx.spatial_scale
+        sampling_ratio = ctx.sampling_ratio
+        bs, ch, h, w = ctx.input_shape
+        grad_input = _C.roi_align_backward(
+            grad_output,
+            rois,
+            spatial_scale,
+            output_size[0],
+            output_size[1],
+            bs,
+            ch,
+            h,
+            w,
+            sampling_ratio,
+        )
+        return grad_input, None, None, None, None
+
+try:
+    import torchvision
+    from torchvision.ops import roi_align
+except:
+    roi_align = _ROIAlign.apply
+
+class ROIAlign(nn.Module):
+    def __init__(self, output_size, spatial_scale, sampling_ratio):
+        super(ROIAlign, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+        self.sampling_ratio = sampling_ratio
+
+    def forward(self, input, rois):
+        return roi_align(
+            input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
+        )
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
+        tmpstr += ")"
+        return tmpstr
+
+class ROIAlignV2(nn.Module):
+    def __init__(self, output_size, spatial_scale, sampling_ratio):
+        super(ROIAlignV2, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+        self.sampling_ratio = sampling_ratio
+
+    def forward(self, input, rois):
+        return torchvision.ops.roi_align(
+            input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, aligned=True
+        )
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
+        tmpstr += ")"
+        return tmpstr
diff --git a/maskrcnn_benchmark/layers/roi_pool.py b/maskrcnn_benchmark/layers/roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..b69efd75f17326a4fd8f306570624fd5ff4ef9b6
--- /dev/null
+++ b/maskrcnn_benchmark/layers/roi_pool.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from maskrcnn_benchmark import _C
+
+
+class _ROIPool(Function):
+    @staticmethod
+    def forward(ctx, input, roi, output_size, spatial_scale):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.input_shape = input.size()
+        output, argmax = _C.roi_pool_forward(
+            input, roi, spatial_scale, output_size[0], output_size[1]
+        )
+        ctx.save_for_backward(input, roi, argmax)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, rois, argmax = ctx.saved_tensors
+        output_size = ctx.output_size
+        spatial_scale = ctx.spatial_scale
+        bs, ch, h, w = ctx.input_shape
+        grad_input = _C.roi_pool_backward(
+            grad_output,
+            input,
+            rois,
+            argmax,
+            spatial_scale,
+            output_size[0],
+            output_size[1],
+            bs,
+            ch,
+            h,
+            w,
+        )
+        return grad_input, None, None, None
+
+
+roi_pool = _ROIPool.apply
+
+
+class ROIPool(nn.Module):
+    def __init__(self, output_size, spatial_scale):
+        super(ROIPool, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+
+    def forward(self, input, rois):
+        return roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ")"
+        return tmpstr
diff --git a/maskrcnn_benchmark/layers/se.py b/maskrcnn_benchmark/layers/se.py
new file mode 100644
index 0000000000000000000000000000000000000000..f10d09217270c14001fec2795b20d14dd5b73586
--- /dev/null
+++ b/maskrcnn_benchmark/layers/se.py
@@ -0,0 +1,52 @@
+from torch import nn
+
+
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class SEBlock(nn.Module):
+    def __init__(self, channels, reduction=16,
+                 use_conv=True, mid_activation=nn.ReLU(inplace=True), out_activation=nn.Sigmoid()):
+        super(SEBlock, self).__init__()
+        self.use_conv = use_conv
+        mid_channels = channels // reduction
+
+        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
+        if use_conv:
+            self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, bias=True)
+        else:
+            self.fc1 = nn.Linear(channels, mid_channels)
+        self.activ = mid_activation
+        if use_conv:
+            self.conv2 = nn.Conv2d(mid_channels, channels, kernel_size=1, bias=True)
+        else:
+            self.fc2 = nn.Linear(mid_channels, channels)
+        self.sigmoid = out_activation
+
+    def forward(self, x):
+        w = self.pool(x)
+        if not self.use_conv:
+            w = w.view(x.size(0), -1)
+        w = self.conv1(w) if self.use_conv else self.fc1(w)
+        w = self.activ(w)
+        w = self.conv2(w) if self.use_conv else self.fc2(w)
+        w = self.sigmoid(w)
+        if not self.use_conv:
+            w = w.unsqueeze(2).unsqueeze(3)
+        x = x * w
+        return x
\ No newline at end of file
diff --git a/maskrcnn_benchmark/layers/set_loss.py b/maskrcnn_benchmark/layers/set_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f504d813708e93b6c31efa61f98d0c8fa2cf0e9
--- /dev/null
+++ b/maskrcnn_benchmark/layers/set_loss.py
@@ -0,0 +1,371 @@
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch import nn
+
+from scipy.optimize import linear_sum_assignment
+from torch.cuda.amp import custom_fwd, custom_bwd
+
+
+def box_area(boxes):
+    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+    area1 = box_area(boxes1)
+    area2 = box_area(boxes2)
+
+    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    union = area1[:, None] + area2 - inter
+
+    iou = inter / union
+    return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+    """
+    Generalized IoU from https://giou.stanford.edu/
+
+    The boxes should be in [x0, y0, x1, y1] format
+
+    Returns a [N, M] pairwise matrix, where N = len(boxes1)
+    and M = len(boxes2)
+    """
+    # degenerate boxes gives inf / nan results
+    # so do an early check
+    #assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+    #assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    iou, union = box_iou(boxes1, boxes2)
+
+    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    area = wh[:, :, 0] * wh[:, :, 1]
+
+    return iou - (area - union) / area
+
+
+def dice_loss(inputs, targets, num_boxes):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * (inputs * targets).sum(1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss.sum() / num_boxes
+
+
+def sigmoid_focal_loss(inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1, gamma: float = 2, reduction: str = "none"):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+                positive vs negative examples. Default = -1 (no weighting).
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+        reduction: 'none' | 'mean' | 'sum'
+                 'none': No reduction will be applied to the output.
+                 'mean': The output will be averaged.
+                 'sum': The output will be summed.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    p = torch.sigmoid(inputs)
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    if reduction == "mean":
+        loss = loss.mean()
+    elif reduction == "sum":
+        loss = loss.sum()
+
+    return loss
+
+
+sigmoid_focal_loss_jit = torch.jit.script(
+    sigmoid_focal_loss
+)  # type: torch.jit.ScriptModule
+
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1,
+                 use_focal: bool = False, focal_loss_alpha: float = 0.25, focal_loss_gamma: float = 2.0,
+                 **kwargs):
+        """Creates the matcher
+
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+        """
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_bbox = cost_bbox
+        self.cost_giou = cost_giou
+        self.use_focal = use_focal
+        if self.use_focal:
+            self.focal_loss_alpha = focal_loss_alpha
+            self.focal_loss_gamma = focal_loss_gamma
+        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
+
+    @torch.no_grad()
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(self, outputs, targets):
+        """ Performs the matching
+
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        if self.use_focal:
+            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
+            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+        else:
+            out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+        # Also concat the target labels and boxes
+        tgt_ids = torch.cat([v["labels"] for v in targets])
+        tgt_bbox = torch.cat([v["boxes_xyxy"] for v in targets])
+
+        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+        # but approximate it in 1 - proba[target class].
+        # The 1 is a constant that doesn't change the matching, it can be ommitted.
+        if self.use_focal:
+            # Compute the classification cost.
+            alpha = self.focal_loss_alpha
+            gamma = self.focal_loss_gamma
+            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+        else:
+            cost_class = -out_prob[:, tgt_ids]
+
+        # Compute the L1 cost between boxes
+        image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets])
+        image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
+        image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
+
+        out_bbox_ = out_bbox / image_size_out
+        tgt_bbox_ = tgt_bbox / image_size_tgt
+        cost_bbox = torch.cdist(out_bbox_, tgt_bbox_, p=1)
+
+        # Compute the giou cost betwen boxes
+        # cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+        cost_giou = -generalized_box_iou(out_bbox, tgt_bbox)
+
+        # Final cost matrix
+        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+        C = C.view(bs, num_queries, -1).cpu()
+
+        C[torch.isnan(C)] = 0.0
+        C[torch.isinf(C)] = 0.0
+
+        sizes = [len(v["boxes"]) for v in targets]
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+class SetCriterion(nn.Module):
+    """
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
+                 use_focal, focal_loss_alpha=0.25, focal_loss_gamma=2.0):
+        """ Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the no-object category
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.eos_coef = eos_coef
+        self.losses = losses
+        self.use_focal = use_focal
+        if self.use_focal:
+            self.focal_loss_alpha = focal_loss_alpha
+            self.focal_loss_gamma = focal_loss_gamma
+        else:
+            empty_weight = torch.ones(self.num_classes + 1)
+            empty_weight[-1] = self.eos_coef
+            self.register_buffer('empty_weight', empty_weight)
+
+    def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert 'pred_logits' in outputs
+        src_logits = outputs['pred_logits']
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+
+        if self.use_focal:
+            src_logits = src_logits.flatten(0, 1)
+            # prepare one_hot target.
+            target_classes = target_classes.flatten(0, 1)
+            pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0]
+            labels = torch.zeros_like(src_logits)
+            labels[pos_inds, target_classes[pos_inds]] = 1
+            # comp focal loss.
+            class_loss = sigmoid_focal_loss_jit(
+                src_logits,
+                labels,
+                alpha=self.focal_loss_alpha,
+                gamma=self.focal_loss_gamma,
+                reduction="sum",
+            ) / num_boxes
+            losses = {'loss_ce': class_loss}
+        else:
+            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+            losses = {'loss_ce': loss_ce}
+
+        return losses
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes_xyxy'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        losses = {}
+        loss_giou = 1 - torch.diag(generalized_box_iou(src_boxes, target_boxes))
+        losses['loss_giou'] = loss_giou.sum() / num_boxes
+
+        image_size = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
+        src_boxes_ = src_boxes / image_size
+        target_boxes_ = target_boxes / image_size
+
+        loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction='none')
+        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'labels': self.loss_labels,
+            'boxes': self.loss_boxes,
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(self, outputs, targets, *argrs, **kwargs):
+        """ This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depends on the losses applied, see each loss' doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes accross all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if dist.is_available() and dist.is_initialized():
+            torch.distributed.all_reduce(num_boxes)
+            word_size = dist.get_world_size()
+        else:
+            word_size = 1
+        num_boxes = torch.clamp(num_boxes / word_size, min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    if loss == 'masks':
+                        # Intermediate masks losses are too costly to compute, we ignore them.
+                        continue
+                    kwargs = {}
+                    if loss == 'labels':
+                        # Logging is enabled only for the last layer
+                        kwargs = {'log': False}
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
+
diff --git a/maskrcnn_benchmark/layers/sigmoid_focal_loss.py b/maskrcnn_benchmark/layers/sigmoid_focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de1bb4c3a003ec705f721e7f1db04baf2e8b268
--- /dev/null
+++ b/maskrcnn_benchmark/layers/sigmoid_focal_loss.py
@@ -0,0 +1,197 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from maskrcnn_benchmark import _C
+
+
+# TODO: Use JIT to replace CUDA implementation in the future.
+class _SigmoidFocalLoss(Function):
+    @staticmethod
+    def forward(ctx, logits, targets, gamma, alpha):
+        ctx.save_for_backward(logits, targets)
+        num_classes = logits.shape[1]
+        ctx.num_classes = num_classes
+        ctx.gamma = gamma
+        ctx.alpha = alpha
+
+        losses = _C.sigmoid_focalloss_forward(
+            logits, targets, num_classes, gamma, alpha
+        )
+        return losses
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, d_loss):
+        logits, targets = ctx.saved_tensors
+        num_classes = ctx.num_classes
+        gamma = ctx.gamma
+        alpha = ctx.alpha
+        d_loss = d_loss.contiguous()
+        d_logits = _C.sigmoid_focalloss_backward(
+            logits, targets, d_loss, num_classes, gamma, alpha
+        )
+        return d_logits, None, None, None, None
+
+
+sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply
+
+
+def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
+    num_classes = logits.shape[1]
+    dtype = targets.dtype
+    device = targets.device
+    class_range = torch.arange(1, num_classes + 1, dtype=dtype, device=device).unsqueeze(0)
+
+    t = targets.unsqueeze(1)
+    p = torch.sigmoid(logits)
+    term1 = (1 - p) ** gamma * torch.log(p)
+    term2 = p ** gamma * torch.log(1 - p)
+    return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)
+
+
+class SigmoidFocalLoss(nn.Module):
+    def __init__(self, gamma, alpha):
+        super(SigmoidFocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+
+    def forward(self, logits, targets):
+        if logits.is_cuda:
+            loss_func = sigmoid_focal_loss_cuda
+        else:
+            loss_func = sigmoid_focal_loss_cpu
+
+        loss = loss_func(logits, targets, self.gamma, self.alpha)
+        return loss.sum()
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "gamma=" + str(self.gamma)
+        tmpstr += ", alpha=" + str(self.alpha)
+        tmpstr += ")"
+        return tmpstr
+
+
+def token_sigmoid_softmax_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
+    # Another modification is that because we use the cross entropy version, there is no frequent or not frequent class.
+    # So we temporarily retired the design of alpha.
+
+    assert (targets.dim() == 3)
+    assert (pred_logits.dim() == 3)  # batch x from x to
+
+    # reprocess target to become probability map ready for softmax
+    targets = targets.float()
+    target_num = targets.sum(-1) + 1e-8  # numerical stability
+    targets = targets / target_num.unsqueeze(-1)  # T(x)
+
+    if text_mask is not None:
+        # reserve the last token for non object
+        assert (text_mask.dim() == 2)
+        text_mask[:, -1] = 1
+        text_mask = (text_mask > 0).unsqueeze(1).repeat(1, pred_logits.size(1), 1)  # copy along the image channel
+        pred_logits = pred_logits.masked_fill(~text_mask, -1000000)  # softmax
+
+    out_prob = pred_logits.softmax(-1)
+
+    filled_targets = targets.clone()
+    filled_targets[filled_targets == 0] = 1.0
+
+    weight = torch.clamp(targets - out_prob, min=0.001) / filled_targets
+    weight = torch.pow(weight, gamma)  # weight = torch.pow(torch.clamp(target - out_prob, min=0.01), gamma)
+
+    loss_ce = - targets * weight * pred_logits.log_softmax(
+        -1)  # only those positives with positive target_sim will have losses.
+    return loss_ce
+
+
+def token_sigmoid_binary_focal_loss_v2(pred_logits, targets, alpha, gamma, text_mask=None):
+    assert (targets.dim() == 3)
+    assert (pred_logits.dim() == 3)  # batch x from x to
+
+    if text_mask is not None:
+        assert (text_mask.dim() == 2)
+
+    # We convert everything into binary
+    out_prob = pred_logits.sigmoid()
+    out_prob_neg_pos = torch.stack([1 - out_prob, out_prob], dim=-1) + 1e-8  # batch x boxes x 256 x 2
+    weight = torch.pow(-out_prob_neg_pos + 1.0, gamma)
+
+    focal_zero = - weight[:, :, :, 0] * torch.log(out_prob_neg_pos[:, :, :, 0]) * (
+            1 - alpha)  # negative class
+    focal_one = - weight[:, :, :, 1] * torch.log(out_prob_neg_pos[:, :, :, 1]) * alpha  # positive class
+    focal = torch.stack([focal_zero, focal_one], dim=-1)
+    loss_ce = torch.gather(focal, index=targets.long().unsqueeze(-1), dim=-1)
+    return loss_ce
+
+
+def token_sigmoid_binary_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
+    # binary version of focal loss
+    # copied from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+                positive vs negative examples. Default = -1 (no weighting).
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    assert (targets.dim() == 3)
+    assert (pred_logits.dim() == 3)  # batch x from x to
+
+    bs, n, _ = pred_logits.shape
+    if text_mask is not None:
+        assert (text_mask.dim() == 2)
+        text_mask = (text_mask > 0).unsqueeze(1)
+        text_mask = text_mask.repeat(1, pred_logits.size(1), 1)  # copy along the image channel dimension
+        pred_logits = torch.masked_select(pred_logits, text_mask)
+        targets = torch.masked_select(targets, text_mask)
+
+        # print(pred_logits.shape)
+        # print(targets.shape)
+
+    p = torch.sigmoid(pred_logits)
+    ce_loss = F.binary_cross_entropy_with_logits(pred_logits, targets, reduction="none")
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss
+
+
+class TokenSigmoidFocalLoss(nn.Module):
+    def __init__(self, alpha, gamma):
+        super(TokenSigmoidFocalLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+
+    def forward(self, logits, targets, text_masks=None, version="binary", **kwargs):
+        if version == "binary":
+            loss_func = token_sigmoid_binary_focal_loss
+        elif version == "softmax":
+            loss_func = token_sigmoid_softmax_focal_loss
+        elif version == "binaryv2":
+            loss_func = token_sigmoid_binary_focal_loss_v2
+        else:
+            raise NotImplementedError
+        loss = loss_func(logits, targets, self.alpha, self.gamma, text_masks, **kwargs)
+        return loss.sum()
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "gamma=" + str(self.gamma)
+        tmpstr += ", alpha=" + str(self.alpha)
+        tmpstr += ")"
+        return tmpstr
diff --git a/maskrcnn_benchmark/layers/smooth_l1_loss.py b/maskrcnn_benchmark/layers/smooth_l1_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2866f6c15f4f301f18181179b5fb835d4d6b7e8
--- /dev/null
+++ b/maskrcnn_benchmark/layers/smooth_l1_loss.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+
+# TODO maybe push this to nn?
+def smooth_l1_loss(input, target, beta=1. / 9, size_average=True):
+    """
+    very similar to the smooth_l1_loss from pytorch, but with
+    the extra beta parameter
+    """
+    n = torch.abs(input - target)
+    cond = n < beta
+    loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
+    if size_average:
+        return loss.mean()
+    return loss.sum()
diff --git a/maskrcnn_benchmark/modeling/.DS_Store b/maskrcnn_benchmark/modeling/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..57ad02856e1a722be1c2932bec7fda4b93bc20b9
Binary files /dev/null and b/maskrcnn_benchmark/modeling/.DS_Store differ
diff --git a/maskrcnn_benchmark/modeling/__init__.py b/maskrcnn_benchmark/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/maskrcnn_benchmark/modeling/backbone/__init__.py b/maskrcnn_benchmark/modeling/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8583983b31e2a7084858ae3d8bb1bb881978e0f
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/__init__.py
@@ -0,0 +1,239 @@
+from collections import OrderedDict
+
+from torch import nn
+
+from maskrcnn_benchmark.modeling import registry
+from maskrcnn_benchmark.modeling.make_layers import conv_with_kaiming_uniform
+from maskrcnn_benchmark.layers import DropBlock2D, DyHead
+from . import fpn as fpn_module
+from . import bifpn
+from . import resnet
+from . import efficientnet
+from . import efficientdet
+from . import swint
+from . import swint_v2
+from . import swint_vl
+from . import swint_v2_vl
+
+
+@registry.BACKBONES.register("R-50-C4")
+@registry.BACKBONES.register("R-50-C5")
+@registry.BACKBONES.register("R-101-C4")
+@registry.BACKBONES.register("R-101-C5")
+def build_resnet_backbone(cfg):
+    body = resnet.ResNet(cfg)
+    model = nn.Sequential(OrderedDict([("body", body)]))
+    return model
+
+
+@registry.BACKBONES.register("R-50-RETINANET")
+@registry.BACKBONES.register("R-101-RETINANET")
+def build_resnet_c5_backbone(cfg):
+    body = resnet.ResNet(cfg)
+    model = nn.Sequential(OrderedDict([("body", body)]))
+    return model
+
+
+@registry.BACKBONES.register("SWINT-FPN-RETINANET")
+def build_retinanet_swint_fpn_backbone(cfg):
+    """
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+    """
+    if cfg.MODEL.SWINT.VERSION == "v1":
+        body = swint.build_swint_backbone(cfg)
+    elif cfg.MODEL.SWINT.VERSION == "v2":
+        body = swint_v2.build_swint_backbone(cfg)
+    elif cfg.MODEL.SWINT.VERSION == "vl":
+        body = swint_vl.build_swint_backbone(cfg)
+    elif cfg.MODEL.SWINT.VERSION == "v2_vl":
+        body = swint_v2_vl.build_swint_backbone(cfg)
+
+    in_channels_stages = cfg.MODEL.SWINT.OUT_CHANNELS
+    out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+    in_channels_p6p7 = out_channels
+    fpn = fpn_module.FPN(
+        in_channels_list=[
+            0,
+            in_channels_stages[-3],
+            in_channels_stages[-2],
+            in_channels_stages[-1],
+            ],
+        out_channels=out_channels,
+        conv_block=conv_with_kaiming_uniform(
+            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
+        ),
+        top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels),
+        drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None,
+        use_spp=cfg.MODEL.FPN.USE_SPP,
+        use_pan=cfg.MODEL.FPN.USE_PAN,
+        return_swint_feature_before_fusion=cfg.MODEL.FPN.RETURN_SWINT_FEATURE_BEFORE_FUSION
+    )
+    if cfg.MODEL.FPN.USE_DYHEAD:
+        dyhead = DyHead(cfg, out_channels)
+        model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn), ("dyhead", dyhead)]))
+    else:
+        model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
+    return model
+
+
+@registry.BACKBONES.register("SWINT-FPN")
+def build_swint_fpn_backbone(cfg):
+    """
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+    """
+    if cfg.MODEL.SWINT.VERSION == "v1":
+        body = swint.build_swint_backbone(cfg)
+    elif cfg.MODEL.SWINT.VERSION == "v2":
+        body = swint_v2.build_swint_backbone(cfg)
+    elif cfg.MODEL.SWINT.VERSION == "vl":
+        body = swint_vl.build_swint_backbone(cfg)
+    elif cfg.MODEL.SWINT.VERSION == "v2_vl":
+        body = swint_v2_vl.build_swint_backbone(cfg)
+
+    in_channels_stages = cfg.MODEL.SWINT.OUT_CHANNELS
+    out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+    fpn = fpn_module.FPN(
+        in_channels_list=[
+            in_channels_stages[-4],
+            in_channels_stages[-3],
+            in_channels_stages[-2],
+            in_channels_stages[-1],
+            ],
+        out_channels=out_channels,
+        conv_block=conv_with_kaiming_uniform(
+            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
+        ),
+        top_blocks=fpn_module.LastLevelMaxPool(),
+        drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None,
+        use_spp=cfg.MODEL.FPN.USE_SPP,
+        use_pan=cfg.MODEL.FPN.USE_PAN
+    )
+    if cfg.MODEL.FPN.USE_DYHEAD:
+        dyhead = DyHead(cfg, out_channels)
+        model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn), ("dyhead", dyhead)]))
+    else:
+        model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
+    return model
+
+
+@registry.BACKBONES.register("CVT-FPN-RETINANET")
+def build_retinanet_cvt_fpn_backbone(cfg):
+    """
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+    """
+    body = cvt.build_cvt_backbone(cfg)
+    in_channels_stages = cfg.MODEL.SPEC.DIM_EMBED
+    out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+    in_channels_p6p7 = out_channels
+    fpn = fpn_module.FPN(
+        in_channels_list=[
+            0,
+            in_channels_stages[-3],
+            in_channels_stages[-2],
+            in_channels_stages[-1],
+            ],
+        out_channels=out_channels,
+        conv_block=conv_with_kaiming_uniform(
+            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
+        ),
+        top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels),
+        drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None,
+        use_spp=cfg.MODEL.FPN.USE_SPP,
+        use_pan=cfg.MODEL.FPN.USE_PAN
+    )
+    if cfg.MODEL.FPN.USE_DYHEAD:
+        dyhead = DyHead(cfg, out_channels)
+        model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn), ("dyhead", dyhead)]))
+    else:
+        model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
+    return model
+
+
+@registry.BACKBONES.register("EFFICIENT7-FPN-RETINANET")
+@registry.BACKBONES.register("EFFICIENT7-FPN-FCOS")
+@registry.BACKBONES.register("EFFICIENT5-FPN-RETINANET")
+@registry.BACKBONES.register("EFFICIENT5-FPN-FCOS")
+@registry.BACKBONES.register("EFFICIENT3-FPN-RETINANET")
+@registry.BACKBONES.register("EFFICIENT3-FPN-FCOS")
+def build_eff_fpn_p6p7_backbone(cfg):
+    version = cfg.MODEL.BACKBONE.CONV_BODY.split('-')[0]
+    version = version.replace('EFFICIENT', 'b')
+    body = efficientnet.get_efficientnet(cfg, version)
+    in_channels_stage = body.out_channels
+    out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+    in_channels_p6p7 = out_channels
+    in_channels_stage[0] = 0
+    fpn = fpn_module.FPN(
+        in_channels_list=in_channels_stage,
+        out_channels=out_channels,
+        conv_block=conv_with_kaiming_uniform(
+            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
+        ),
+        top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels),
+        drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None,
+        use_spp=cfg.MODEL.FPN.USE_SPP,
+        use_pan=cfg.MODEL.FPN.USE_PAN
+    )
+    model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
+    return model
+
+
+@registry.BACKBONES.register("EFFICIENT7-BIFPN-RETINANET")
+@registry.BACKBONES.register("EFFICIENT7-BIFPN-FCOS")
+@registry.BACKBONES.register("EFFICIENT5-BIFPN-RETINANET")
+@registry.BACKBONES.register("EFFICIENT5-BIFPN-FCOS")
+@registry.BACKBONES.register("EFFICIENT3-BIFPN-RETINANET")
+@registry.BACKBONES.register("EFFICIENT3-BIFPN-FCOS")
+def build_eff_fpn_p6p7_backbone(cfg):
+    version = cfg.MODEL.BACKBONE.CONV_BODY.split('-')[0]
+    version = version.replace('EFFICIENT', 'b')
+    body = efficientnet.get_efficientnet(cfg, version)
+    in_channels_stage = body.out_channels
+    out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+    bifpns = nn.ModuleList()
+    for i in range(cfg.MODEL.BIFPN.NUM_REPEATS):
+        first_time = (i==0)
+        fpn = bifpn.BiFPN(
+            in_channels_list=in_channels_stage[1:],
+            out_channels=out_channels,
+            first_time=first_time,
+            attention=cfg.MODEL.BIFPN.USE_ATTENTION
+        )
+        bifpns.append(fpn)
+    model = nn.Sequential(OrderedDict([("body", body), ("bifpn", bifpns)]))
+    return model
+
+
+@registry.BACKBONES.register("EFFICIENT-DET")
+def build_efficientdet_backbone(cfg):
+    efficientdet.g_simple_padding = True
+    compound = cfg.MODEL.BACKBONE.EFFICIENT_DET_COMPOUND
+    start_from = cfg.MODEL.BACKBONE.EFFICIENT_DET_START_FROM
+    model = efficientdet.EffNetFPN(
+        compound_coef=compound,
+        start_from=start_from,
+    )
+    if cfg.MODEL.BACKBONE.USE_SYNCBN:
+        import torch
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+    return model
+
+
+def build_backbone(cfg):
+    assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
+        "cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format(
+            cfg.MODEL.BACKBONE.CONV_BODY
+        )
+    return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg)
diff --git a/maskrcnn_benchmark/modeling/backbone/bifpn.py b/maskrcnn_benchmark/modeling/backbone/bifpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8689c1c32e61f7984559eb78e7f3e7828b3c2abc
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/bifpn.py
@@ -0,0 +1,273 @@
+import torch.nn as nn
+import torch
+
+from maskrcnn_benchmark.layers import swish
+
+
+class BiFPN(nn.Module):
+    def __init__(self, in_channels_list, out_channels, first_time=False, epsilon=1e-4, attention=True):
+        super(BiFPN, self).__init__()
+        self.epsilon = epsilon
+        # Conv layers
+        self.conv6_up = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+        self.conv5_up = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+        self.conv4_up = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+        self.conv3_up = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+        self.conv4_down = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+        self.conv5_down = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+        self.conv6_down = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+        self.conv7_down = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False),
+            nn.Conv2d(out_channels, out_channels, 1),
+            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+        )
+
+        # Feature scaling layers
+        self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+
+        self.p4_downsample = nn.MaxPool2d(3, 2)
+        self.p5_downsample = nn.MaxPool2d(3, 2)
+        self.p6_downsample = nn.MaxPool2d(3, 2)
+        self.p7_downsample = nn.MaxPool2d(3, 2)
+
+        self.swish = swish()
+
+        self.first_time = first_time
+        if self.first_time:
+            self.p5_down_channel = nn.Sequential(
+                nn.Conv2d(in_channels_list[2], out_channels, 1),
+                nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+            )
+            self.p4_down_channel = nn.Sequential(
+                nn.Conv2d(in_channels_list[1], out_channels, 1),
+                nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+            )
+            self.p3_down_channel = nn.Sequential(
+                nn.Conv2d(in_channels_list[0], out_channels, 1),
+                nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+            )
+
+            self.p5_to_p6 = nn.Sequential(
+                nn.Conv2d(in_channels_list[2], out_channels, 1),
+                nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+                nn.MaxPool2d(3, 2)
+            )
+            self.p6_to_p7 = nn.Sequential(
+                nn.MaxPool2d(3, 2)
+            )
+
+            self.p4_down_channel_2 = nn.Sequential(
+                nn.Conv2d(in_channels_list[1], out_channels, 1),
+                nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+            )
+            self.p5_down_channel_2 = nn.Sequential(
+                nn.Conv2d(in_channels_list[2], out_channels, 1),
+                nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
+            )
+
+        # Weight
+        self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p6_w1_relu = nn.ReLU()
+        self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p5_w1_relu = nn.ReLU()
+        self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p4_w1_relu = nn.ReLU()
+        self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p3_w1_relu = nn.ReLU()
+
+        self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+        self.p4_w2_relu = nn.ReLU()
+        self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+        self.p5_w2_relu = nn.ReLU()
+        self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+        self.p6_w2_relu = nn.ReLU()
+        self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p7_w2_relu = nn.ReLU()
+
+        self.attention = attention
+
+    def forward(self, inputs):
+        """
+        illustration of a minimal bifpn unit
+            P7_0 -------------------------> P7_2 -------->
+               |-------------|                ↑
+                             ↓                |
+            P6_0 ---------> P6_1 ---------> P6_2 -------->
+               |-------------|--------------↑ ↑
+                             ↓                |
+            P5_0 ---------> P5_1 ---------> P5_2 -------->
+               |-------------|--------------↑ ↑
+                             ↓                |
+            P4_0 ---------> P4_1 ---------> P4_2 -------->
+               |-------------|--------------↑ ↑
+                             |--------------↓ |
+            P3_0 -------------------------> P3_2 -------->
+        """
+
+        # downsample channels using same-padding conv2d to target phase's if not the same
+        # judge: same phase as target,
+        # if same, pass;
+        # elif earlier phase, downsample to target phase's by pooling
+        # elif later phase, upsample to target phase's by nearest interpolation
+
+        if self.attention:
+            p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs)
+        else:
+            p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs)
+
+        return p3_out, p4_out, p5_out, p6_out, p7_out
+
+    def _forward_fast_attention(self, inputs):
+        if self.first_time:
+            p3, p4, p5 = inputs[-3:]
+
+            p6_in = self.p5_to_p6(p5)
+            p7_in = self.p6_to_p7(p6_in)
+
+            p3_in = self.p3_down_channel(p3)
+            p4_in = self.p4_down_channel(p4)
+            p5_in = self.p5_down_channel(p5)
+
+        else:
+            # P3_0, P4_0, P5_0, P6_0 and P7_0
+            p3_in, p4_in, p5_in, p6_in, p7_in = inputs
+
+        # P7_0 to P7_2
+
+        # Weights for P6_0 and P7_0 to P6_1
+        p6_w1 = self.p6_w1_relu(self.p6_w1)
+        weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
+        # Connections for P6_0 and P7_0 to P6_1 respectively
+        p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))
+
+        # Weights for P5_0 and P6_1 to P5_1
+        p5_w1 = self.p5_w1_relu(self.p5_w1)
+        weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
+        # Connections for P5_0 and P6_1 to P5_1 respectively
+        p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)))
+
+        # Weights for P4_0 and P5_1 to P4_1
+        p4_w1 = self.p4_w1_relu(self.p4_w1)
+        weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
+        # Connections for P4_0 and P5_1 to P4_1 respectively
+        p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)))
+
+        # Weights for P3_0 and P4_1 to P3_2
+        p3_w1 = self.p3_w1_relu(self.p3_w1)
+        weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
+        # Connections for P3_0 and P4_1 to P3_2 respectively
+        p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)))
+
+        if self.first_time:
+            p4_in = self.p4_down_channel_2(p4)
+            p5_in = self.p5_down_channel_2(p5)
+
+        # Weights for P4_0, P4_1 and P3_2 to P4_2
+        p4_w2 = self.p4_w2_relu(self.p4_w2)
+        weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
+        # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
+        p4_out = self.conv4_down(
+            self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out)))
+
+        # Weights for P5_0, P5_1 and P4_2 to P5_2
+        p5_w2 = self.p5_w2_relu(self.p5_w2)
+        weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
+        # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
+        p5_out = self.conv5_down(
+            self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)))
+
+        # Weights for P6_0, P6_1 and P5_2 to P6_2
+        p6_w2 = self.p6_w2_relu(self.p6_w2)
+        weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
+        # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
+        p6_out = self.conv6_down(
+            self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)))
+
+        # Weights for P7_0 and P6_2 to P7_2
+        p7_w2 = self.p7_w2_relu(self.p7_w2)
+        weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
+        # Connections for P7_0 and P6_2 to P7_2
+        p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))
+
+        return p3_out, p4_out, p5_out, p6_out, p7_out
+
+    def _forward(self, inputs):
+        if self.first_time:
+            p3, p4, p5 = inputs
+
+            p6_in = self.p5_to_p6(p5)
+            p7_in = self.p6_to_p7(p6_in)
+
+            p3_in = self.p3_down_channel(p3)
+            p4_in = self.p4_down_channel(p4)
+            p5_in = self.p5_down_channel(p5)
+
+        else:
+            # P3_0, P4_0, P5_0, P6_0 and P7_0
+            p3_in, p4_in, p5_in, p6_in, p7_in = inputs
+
+        # P7_0 to P7_2
+
+        # Connections for P6_0 and P7_0 to P6_1 respectively
+        p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
+
+        # Connections for P5_0 and P6_1 to P5_1 respectively
+        p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))
+
+        # Connections for P4_0 and P5_1 to P4_1 respectively
+        p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up)))
+
+        # Connections for P3_0 and P4_1 to P3_2 respectively
+        p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up)))
+
+        if self.first_time:
+            p4_in = self.p4_down_channel_2(p4)
+            p5_in = self.p5_down_channel_2(p5)
+
+        # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
+        p4_out = self.conv4_down(
+            self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))
+
+        # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
+        p5_out = self.conv5_down(
+            self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))
+
+        # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
+        p6_out = self.conv6_down(
+            self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
+
+        # Connections for P7_0 and P6_2 to P7_2
+        p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))
+
+        return p3_out, p4_out, p5_out, p6_out, p7_out
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/blocks.py b/maskrcnn_benchmark/modeling/backbone/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..eab3b74a2e129abe07fb5d30776db77c46a648dd
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/blocks.py
@@ -0,0 +1,266 @@
+import torch.nn as nn
+from .ops import *
+
+
+class stem(nn.Module):
+    num_layer = 1
+
+    def __init__(self, conv, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d):
+        super(stem, self).__init__()
+
+        self.conv1 = conv(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        return out
+
+
+class basic(nn.Module):
+    expansion = 1
+    num_layer = 2
+
+    def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
+        super(basic, self).__init__()
+        midplanes = planes if midplanes is None else midplanes
+        self.conv1 = conv(inplanes, midplanes, stride)
+        self.bn1 = norm_layer(midplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv(midplanes, planes)
+        self.bn2 = norm_layer(planes)
+        if stride!=1 or inplanes!=planes*self.expansion:
+            self.downsample = nn.Sequential(
+                conv1x1(inplanes, planes, stride),
+                norm_layer(planes),
+            )
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class bottleneck(nn.Module):
+    expansion = 4
+    num_layer = 3
+
+    def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
+        super(bottleneck, self).__init__()
+        midplanes = planes if midplanes is None else midplanes
+        self.conv1 = conv1x1(inplanes, midplanes)
+        self.bn1 = norm_layer(midplanes)
+        self.conv2 = conv(midplanes, midplanes, stride)
+        self.bn2 = norm_layer(midplanes)
+        self.conv3 = conv1x1(midplanes, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        if stride!=1 or inplanes!=planes*self.expansion:
+            self.downsample = nn.Sequential(
+                conv1x1(inplanes, planes*self.expansion, stride),
+                norm_layer(planes*self.expansion),
+            )
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class invert(nn.Module):
+    def __init__(self, conv, inp, oup, stride=1, expand_ratio=1, norm_layer=nn.BatchNorm2d):
+        super(invert, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = round(inp * expand_ratio)
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        if expand_ratio == 1:
+            self.conv = nn.Sequential(
+                # dw
+                conv(hidden_dim, hidden_dim, stride),
+                norm_layer(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                norm_layer(oup),
+            )
+        else:
+            self.conv = nn.Sequential(
+                # pw
+                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+                norm_layer(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # dw
+                conv(hidden_dim, hidden_dim, stride),
+                norm_layer(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                norm_layer(oup),
+            )
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+invert2 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=2, **kwargs)
+invert3 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=3, **kwargs)
+invert4 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=4, **kwargs)
+invert6 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=6, **kwargs)
+
+
+def channel_shuffle(x, groups):
+    batchsize, num_channels, height, width = x.data.size()
+    channels_per_group = num_channels // groups
+    # reshape
+    x = x.view(batchsize, groups, channels_per_group, height, width)
+    x = torch.transpose(x, 1, 2).contiguous()
+    # flatten
+    x = x.view(batchsize, -1, height, width)
+    return x
+
+
+class shuffle(nn.Module):
+    expansion = 1
+    num_layer = 3
+
+    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
+        super(shuffle, self).__init__()
+        inplanes = inplanes // 2 if stride == 1 else inplanes
+        midplanes = outplanes // 2 if midplanes is None else midplanes
+        rightoutplanes = outplanes - inplanes
+        if stride == 2:
+            self.left_branch = nn.Sequential(
+                # dw
+                conv(inplanes, inplanes, stride),
+                norm_layer(inplanes),
+                # pw-linear
+                conv1x1(inplanes, inplanes),
+                norm_layer(inplanes),
+                nn.ReLU(inplace=True),
+            )
+
+        self.right_branch = nn.Sequential(
+            # pw
+            conv1x1(inplanes, midplanes),
+            norm_layer(midplanes),
+            nn.ReLU(inplace=True),
+            # dw
+            conv(midplanes, midplanes, stride),
+            norm_layer(midplanes),
+            # pw-linear
+            conv1x1(midplanes, rightoutplanes),
+            norm_layer(rightoutplanes),
+            nn.ReLU(inplace=True),
+        )
+
+        self.reduce = stride==2
+
+    def forward(self, x):
+        if self.reduce:
+            out = torch.cat((self.left_branch(x), self.right_branch(x)), 1)
+        else:
+            x1 = x[:, :(x.shape[1]//2), :, :]
+            x2 = x[:, (x.shape[1]//2):, :, :]
+            out = torch.cat((x1, self.right_branch(x2)), 1)
+
+        return channel_shuffle(out, 2)
+
+
+class shufflex(nn.Module):
+    expansion = 1
+    num_layer = 3
+
+    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
+        super(shufflex, self).__init__()
+        inplanes = inplanes // 2 if stride == 1 else inplanes
+        midplanes = outplanes // 2 if midplanes is None else midplanes
+        rightoutplanes = outplanes - inplanes
+        if stride==2:
+            self.left_branch = nn.Sequential(
+                # dw
+                conv(inplanes, inplanes, stride),
+                norm_layer(inplanes),
+                # pw-linear
+                conv1x1(inplanes, inplanes),
+                norm_layer(inplanes),
+                nn.ReLU(inplace=True),
+            )
+
+        self.right_branch = nn.Sequential(
+            # dw
+            conv(inplanes, inplanes, stride),
+            norm_layer(inplanes),
+            # pw-linear
+            conv1x1(inplanes, midplanes),
+            norm_layer(midplanes),
+            nn.ReLU(inplace=True),
+            # dw
+            conv(midplanes, midplanes, 1),
+            norm_layer(midplanes),
+            # pw-linear
+            conv1x1(midplanes, midplanes),
+            norm_layer(midplanes),
+            nn.ReLU(inplace=True),
+            # dw
+            conv(midplanes, midplanes, 1),
+            norm_layer(midplanes),
+            # pw-linear
+            conv1x1(midplanes, rightoutplanes),
+            norm_layer(rightoutplanes),
+            nn.ReLU(inplace=True),
+        )
+
+        self.reduce = stride==2
+
+    def forward(self, x):
+        if self.reduce:
+            out = torch.cat((self.left_branch(x), self.right_branch(x)), 1)
+        else:
+            x1 = x[:, :(x.shape[1] // 2), :, :]
+            x2 = x[:, (x.shape[1] // 2):, :, :]
+            out = torch.cat((x1, self.right_branch(x2)), 1)
+
+        return channel_shuffle(out, 2)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/efficientdet.py b/maskrcnn_benchmark/modeling/backbone/efficientdet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d5666815cd2e94c954929bd38786e89b4c19d89
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/efficientdet.py
@@ -0,0 +1,1882 @@
+import torch
+import re
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+import logging
+import cv2
+import math
+import itertools
+import collections
+from torchvision.ops import nms
+
+
+GlobalParams = collections.namedtuple('GlobalParams', [
+    'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
+    'num_classes', 'width_coefficient', 'depth_coefficient',
+    'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
+
+# Parameters for an individual model block
+BlockArgs = collections.namedtuple('BlockArgs', [
+    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
+    'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
+
+# https://stackoverflow.com/a/18348004
+# Change namedtuple defaults
+GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
+BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
+
+# in the old version, g_simple_padding = False, which tries to align
+# tensorflow's implementation, which is not required here.
+g_simple_padding = True
+class MaxPool2dStaticSamePadding(nn.Module):
+    """
+    created by Zylo117
+    The real keras/tensorflow MaxPool2d with same padding
+    """
+
+    def __init__(self, kernel_size, stride):
+        super().__init__()
+        if g_simple_padding:
+            self.pool = nn.MaxPool2d(kernel_size, stride,
+                                     padding=(kernel_size-1)//2)
+        else:
+            assert ValueError()
+            self.pool = nn.MaxPool2d(kernel_size, stride)
+            self.stride = self.pool.stride
+            self.kernel_size = self.pool.kernel_size
+
+            if isinstance(self.stride, int):
+                self.stride = [self.stride] * 2
+            elif len(self.stride) == 1:
+                self.stride = [self.stride[0]] * 2
+
+            if isinstance(self.kernel_size, int):
+                self.kernel_size = [self.kernel_size] * 2
+            elif len(self.kernel_size) == 1:
+                self.kernel_size = [self.kernel_size[0]] * 2
+
+    def forward(self, x):
+        if g_simple_padding:
+            return self.pool(x)
+        else:
+            assert ValueError()
+            h, w = x.shape[-2:]
+
+            h_step = math.ceil(w / self.stride[1])
+            v_step = math.ceil(h / self.stride[0])
+            h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
+            v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
+
+            extra_h = h_cover_len - w
+            extra_v = v_cover_len - h
+
+            left = extra_h // 2
+            right = extra_h - left
+            top = extra_v // 2
+            bottom = extra_v - top
+
+            x = F.pad(x, [left, right, top, bottom])
+
+            x = self.pool(x)
+        return x
+
+class Conv2dStaticSamePadding(nn.Module):
+    """
+    created by Zylo117
+    The real keras/tensorflow conv2d with same padding
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
+        super().__init__()
+        if g_simple_padding:
+            assert kernel_size % 2 == 1
+            assert dilation == 1
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
+                                  bias=bias,
+                                  groups=groups,
+                                  padding=(kernel_size - 1) // 2)
+            self.stride = self.conv.stride
+            if isinstance(self.stride, int):
+                self.stride = [self.stride] * 2
+            elif len(self.stride) == 1:
+                self.stride = [self.stride[0]] * 2
+            else:
+                self.stride = list(self.stride)
+        else:
+            assert ValueError()
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
+                                  bias=bias, groups=groups)
+            self.stride = self.conv.stride
+            self.kernel_size = self.conv.kernel_size
+            self.dilation = self.conv.dilation
+
+            if isinstance(self.stride, int):
+                self.stride = [self.stride] * 2
+            elif len(self.stride) == 1:
+                self.stride = [self.stride[0]] * 2
+
+            if isinstance(self.kernel_size, int):
+                self.kernel_size = [self.kernel_size] * 2
+            elif len(self.kernel_size) == 1:
+                self.kernel_size = [self.kernel_size[0]] * 2
+
+    def forward(self, x):
+        if g_simple_padding:
+            return self.conv(x)
+        else:
+            assert ValueError()
+            h, w = x.shape[-2:]
+
+            h_step = math.ceil(w / self.stride[1])
+            v_step = math.ceil(h / self.stride[0])
+            h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
+            v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
+
+            extra_h = h_cover_len - w
+            extra_v = v_cover_len - h
+
+            left = extra_h // 2
+            right = extra_h - left
+            top = extra_v // 2
+            bottom = extra_v - top
+
+            x = F.pad(x, [left, right, top, bottom])
+
+            x = self.conv(x)
+            return x
+
+class SeparableConvBlock(nn.Module):
+    """
+    created by Zylo117
+    """
+
+    def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False):
+        super(SeparableConvBlock, self).__init__()
+        if out_channels is None:
+            out_channels = in_channels
+
+        # Q: whether separate conv
+        #  share bias between depthwise_conv and pointwise_conv
+        #  or just pointwise_conv apply bias.
+        # A: Confirmed, just pointwise_conv applies bias, depthwise_conv has no bias.
+
+        self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels,
+                                                      kernel_size=3, stride=1, groups=in_channels, bias=False)
+        self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)
+
+        self.norm = norm
+        if self.norm:
+            # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow
+            self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3)
+
+        self.activation = activation
+        if self.activation:
+            self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
+
+    def forward(self, x):
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+
+        if self.norm:
+            x = self.bn(x)
+
+        if self.activation:
+            x = self.swish(x)
+
+        return x
+
+
+class BiFPN(nn.Module):
+    """
+    modified by Zylo117
+    """
+
+    def __init__(self, num_channels, conv_channels, first_time=False,
+                 epsilon=1e-4, onnx_export=False, attention=True,
+                 adaptive_up=False):
+        """
+
+        Args:
+            num_channels:
+            conv_channels:
+            first_time: whether the input comes directly from the efficientnet,
+                        if True, downchannel it first, and downsample P5 to generate P6 then P7
+            epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon
+            onnx_export: if True, use Swish instead of MemoryEfficientSwish
+        """
+        super(BiFPN, self).__init__()
+        self.epsilon = epsilon
+        # Conv layers
+        self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+        self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+        self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+        self.conv3_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+        self.conv4_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+        self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+        self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+        self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
+
+        # Feature scaling layers
+        self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
+
+        self.adaptive_up = adaptive_up
+
+        self.p4_downsample = MaxPool2dStaticSamePadding(3, 2)
+        self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
+        self.p6_downsample = MaxPool2dStaticSamePadding(3, 2)
+        self.p7_downsample = MaxPool2dStaticSamePadding(3, 2)
+
+        self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
+
+        self.first_time = first_time
+        if self.first_time:
+            self.p5_down_channel = nn.Sequential(
+                Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
+                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
+            )
+            self.p4_down_channel = nn.Sequential(
+                Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
+                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
+            )
+            self.p3_down_channel = nn.Sequential(
+                Conv2dStaticSamePadding(conv_channels[0], num_channels, 1),
+                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
+            )
+
+            if len(conv_channels) == 3:
+                self.p5_to_p6 = nn.Sequential(
+                    Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
+                    nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
+                    MaxPool2dStaticSamePadding(3, 2)
+                )
+            else:
+                assert len(conv_channels) == 4
+                self.p6_down_channel = nn.Sequential(
+                    Conv2dStaticSamePadding(conv_channels[3], num_channels, 1),
+                    nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
+                )
+
+            self.p6_to_p7 = nn.Sequential(
+                MaxPool2dStaticSamePadding(3, 2)
+            )
+
+            self.p4_down_channel_2 = nn.Sequential(
+                Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
+                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
+            )
+            self.p5_down_channel_2 = nn.Sequential(
+                Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
+                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
+            )
+
+        # Weight
+        self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p6_w1_relu = nn.ReLU()
+        self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p5_w1_relu = nn.ReLU()
+        self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p4_w1_relu = nn.ReLU()
+        self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p3_w1_relu = nn.ReLU()
+
+        self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+        self.p4_w2_relu = nn.ReLU()
+        self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+        self.p5_w2_relu = nn.ReLU()
+        self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
+        self.p6_w2_relu = nn.ReLU()
+        self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
+        self.p7_w2_relu = nn.ReLU()
+
+        self.attention = attention
+
+    def forward(self, inputs):
+        """
+        illustration of a minimal bifpn unit
+            P7_0 -------------------------> P7_2 -------->
+               |-------------|                ↑
+                             ↓                |
+            P6_0 ---------> P6_1 ---------> P6_2 -------->
+               |-------------|--------------↑ ↑
+                             ↓                |
+            P5_0 ---------> P5_1 ---------> P5_2 -------->
+               |-------------|--------------↑ ↑
+                             ↓                |
+            P4_0 ---------> P4_1 ---------> P4_2 -------->
+               |-------------|--------------↑ ↑
+                             |--------------↓ |
+            P3_0 -------------------------> P3_2 -------->
+        """
+
+        # downsample channels using same-padding conv2d to target phase's if not the same
+        # judge: same phase as target,
+        # if same, pass;
+        # elif earlier phase, downsample to target phase's by pooling
+        # elif later phase, upsample to target phase's by nearest interpolation
+        if self.attention:
+            p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs)
+        else:
+            p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs)
+
+        return p3_out, p4_out, p5_out, p6_out, p7_out
+
+    def _forward_fast_attention(self, inputs):
+        if self.first_time:
+            if len(inputs) == 3:
+                p3, p4, p5 = inputs
+                p6_in = self.p5_to_p6(p5)
+            else:
+                p3, p4, p5, p6 = inputs
+                p6_in = self.p6_down_channel(p6)
+
+            p7_in = self.p6_to_p7(p6_in)
+
+            p3_in = self.p3_down_channel(p3)
+            p4_in = self.p4_down_channel(p4)
+            p5_in = self.p5_down_channel(p5)
+        else:
+            # P3_0, P4_0, P5_0, P6_0 and P7_0
+            p3_in, p4_in, p5_in, p6_in, p7_in = inputs
+
+        # P7_0 to P7_2
+
+        if not self.adaptive_up:
+            # Weights for P6_0 and P7_0 to P6_1
+            p6_w1 = self.p6_w1_relu(self.p6_w1)
+            weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
+            # Connections for P6_0 and P7_0 to P6_1 respectively
+            p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))
+
+            # Weights for P5_0 and P6_0 to P5_1
+            p5_w1 = self.p5_w1_relu(self.p5_w1)
+            weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
+            # Connections for P5_0 and P6_0 to P5_1 respectively
+            p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)))
+
+            # Weights for P4_0 and P5_0 to P4_1
+            p4_w1 = self.p4_w1_relu(self.p4_w1)
+            weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
+            # Connections for P4_0 and P5_0 to P4_1 respectively
+            p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)))
+
+            # Weights for P3_0 and P4_1 to P3_2
+            p3_w1 = self.p3_w1_relu(self.p3_w1)
+            weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
+            # Connections for P3_0 and P4_1 to P3_2 respectively
+            p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)))
+        else:
+            # Weights for P6_0 and P7_0 to P6_1
+            p6_w1 = self.p6_w1_relu(self.p6_w1)
+            weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
+            # Connections for P6_0 and P7_0 to P6_1 respectively
+            p6_upsample = nn.Upsample(size=p6_in.shape[-2:])
+            p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * p6_upsample(p7_in)))
+
+            # Weights for P5_0 and P6_0 to P5_1
+            p5_w1 = self.p5_w1_relu(self.p5_w1)
+            weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
+            # Connections for P5_0 and P6_0 to P5_1 respectively
+            p5_upsample = nn.Upsample(size=p5_in.shape[-2:])
+            p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * p5_upsample(p6_up)))
+
+            # Weights for P4_0 and P5_0 to P4_1
+            p4_w1 = self.p4_w1_relu(self.p4_w1)
+            weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
+            # Connections for P4_0 and P5_0 to P4_1 respectively
+            p4_upsample = nn.Upsample(size=p4_in.shape[-2:])
+            p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * p4_upsample(p5_up)))
+
+            # Weights for P3_0 and P4_1 to P3_2
+            p3_w1 = self.p3_w1_relu(self.p3_w1)
+            weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
+            p3_upsample = nn.Upsample(size=p3_in.shape[-2:])
+            # Connections for P3_0 and P4_1 to P3_2 respectively
+            p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * p3_upsample(p4_up)))
+
+        if self.first_time:
+            p4_in = self.p4_down_channel_2(p4)
+            p5_in = self.p5_down_channel_2(p5)
+
+        # Weights for P4_0, P4_1 and P3_2 to P4_2
+        p4_w2 = self.p4_w2_relu(self.p4_w2)
+        weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
+        # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
+        p4_out = self.conv4_down(
+            self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out)))
+
+        # Weights for P5_0, P5_1 and P4_2 to P5_2
+        p5_w2 = self.p5_w2_relu(self.p5_w2)
+        weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
+        # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
+        p5_out = self.conv5_down(
+            self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)))
+
+        # Weights for P6_0, P6_1 and P5_2 to P6_2
+        p6_w2 = self.p6_w2_relu(self.p6_w2)
+        weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
+        # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
+        p6_out = self.conv6_down(
+            self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)))
+
+        # Weights for P7_0 and P6_2 to P7_2
+        p7_w2 = self.p7_w2_relu(self.p7_w2)
+        weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
+        # Connections for P7_0 and P6_2 to P7_2
+        p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))
+
+        return p3_out, p4_out, p5_out, p6_out, p7_out
+
+    def _forward(self, inputs):
+        if self.first_time:
+            p3, p4, p5 = inputs
+
+            p6_in = self.p5_to_p6(p5)
+            p7_in = self.p6_to_p7(p6_in)
+
+            p3_in = self.p3_down_channel(p3)
+            p4_in = self.p4_down_channel(p4)
+            p5_in = self.p5_down_channel(p5)
+
+        else:
+            # P3_0, P4_0, P5_0, P6_0 and P7_0
+            p3_in, p4_in, p5_in, p6_in, p7_in = inputs
+
+        # P7_0 to P7_2
+
+        # Connections for P6_0 and P7_0 to P6_1 respectively
+        p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
+
+        # Connections for P5_0 and P6_0 to P5_1 respectively
+        p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))
+
+        # Connections for P4_0 and P5_0 to P4_1 respectively
+        p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up)))
+
+        # Connections for P3_0 and P4_1 to P3_2 respectively
+        p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up)))
+
+        if self.first_time:
+            p4_in = self.p4_down_channel_2(p4)
+            p5_in = self.p5_down_channel_2(p5)
+
+        # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
+        p4_out = self.conv4_down(
+            self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))
+
+        # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
+        p5_out = self.conv5_down(
+            self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))
+
+        # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
+        p6_out = self.conv6_down(
+            self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
+
+        # Connections for P7_0 and P6_2 to P7_2
+        p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))
+
+        return p3_out, p4_out, p5_out, p6_out, p7_out
+
+
+class Regressor(nn.Module):
+    """
+    modified by Zylo117
+    """
+
+    def __init__(self, in_channels, num_anchors, num_layers, onnx_export=False):
+        super(Regressor, self).__init__()
+        self.num_layers = num_layers
+        self.num_layers = num_layers
+
+        self.conv_list = nn.ModuleList(
+            [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
+        self.bn_list = nn.ModuleList(
+            [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
+             range(5)])
+        self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False)
+        self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
+
+    def forward(self, inputs):
+        feats = []
+        for feat, bn_list in zip(inputs, self.bn_list):
+            for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
+                feat = conv(feat)
+                feat = bn(feat)
+                feat = self.swish(feat)
+            feat = self.header(feat)
+            feat = feat.permute(0, 2, 3, 1)
+            feat = feat.contiguous().view(feat.shape[0], -1, 4)
+
+            feats.append(feat)
+
+        feats = torch.cat(feats, dim=1)
+
+        return feats
+
+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_variables[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        if torch._C._get_tracing_state():
+            return x * torch.sigmoid(x)
+        return SwishImplementation.apply(x)
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+class Classifier(nn.Module):
+    """
+    modified by Zylo117
+    """
+
+    def __init__(self, in_channels, num_anchors, num_classes, num_layers,
+                 onnx_export=False, prior_prob=0.01):
+        super(Classifier, self).__init__()
+        self.num_anchors = num_anchors
+        self.num_classes = num_classes
+        self.num_layers = num_layers
+        self.conv_list = nn.ModuleList(
+            [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
+        self.bn_list = nn.ModuleList(
+            [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
+             range(5)])
+        self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False)
+
+        prior_prob = prior_prob
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        torch.nn.init.normal_(self.header.pointwise_conv.conv.weight, std=0.01)
+        torch.nn.init.constant_(self.header.pointwise_conv.conv.bias, bias_value)
+
+        self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
+
+    def forward(self, inputs):
+        feats = []
+        for feat, bn_list in zip(inputs, self.bn_list):
+            for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
+                feat = conv(feat)
+                feat = bn(feat)
+                feat = self.swish(feat)
+            feat = self.header(feat)
+
+            feat = feat.permute(0, 2, 3, 1)
+            feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors,
+                                          self.num_classes)
+            feat = feat.contiguous().view(feat.shape[0], -1, self.num_classes)
+
+            feats.append(feat)
+
+        feats = torch.cat(feats, dim=1)
+        #feats = feats.sigmoid()
+
+        return feats
+
+class Conv2dDynamicSamePadding(nn.Conv2d):
+    """ 2D Convolutions like TensorFlow, for a dynamic image size """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+        raise ValueError('tend to be deprecated')
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+    def forward(self, x):
+        ih, iw = x.size()[-2:]
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+#TODO: it seems like the standard conv layer is good enough with proper padding
+# parameters.
+def get_same_padding_conv2d(image_size=None):
+    """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+        Static padding is necessary for ONNX exporting of models. """
+    if image_size is None:
+        raise ValueError('not validated')
+        return Conv2dDynamicSamePadding
+    else:
+        from functools import partial
+        return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+def round_filters(filters, global_params):
+    """ Calculate and round number of filters based on depth multiplier. """
+    multiplier = global_params.width_coefficient
+    if not multiplier:
+        return filters
+    divisor = global_params.depth_divisor
+    min_depth = global_params.min_depth
+    filters *= multiplier
+    min_depth = min_depth or divisor
+    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
+        new_filters += divisor
+    return int(new_filters)
+
+def round_repeats(repeats, global_params):
+    """ Round number of filters based on depth multiplier. """
+    multiplier = global_params.depth_coefficient
+    if not multiplier:
+        return repeats
+    return int(math.ceil(multiplier * repeats))
+
+def drop_connect(inputs, p, training):
+    """ Drop connect. """
+    if not training: return inputs
+    batch_size = inputs.shape[0]
+    keep_prob = 1 - p
+    random_tensor = keep_prob
+    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
+    binary_tensor = torch.floor(random_tensor)
+    output = inputs / keep_prob * binary_tensor
+    return output
+
+class MBConvBlock(nn.Module):
+    """
+    Mobile Inverted Residual Bottleneck Block
+
+    Args:
+        block_args (namedtuple): BlockArgs, see above
+        global_params (namedtuple): GlobalParam, see above
+
+    Attributes:
+        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
+    """
+
+    def __init__(self, block_args, global_params):
+        super().__init__()
+        self._block_args = block_args
+        self._bn_mom = 1 - global_params.batch_norm_momentum
+        self._bn_eps = global_params.batch_norm_epsilon
+        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+        self.id_skip = block_args.id_skip  # skip connection and drop connect
+
+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
+        # Expansion phase
+        inp = self._block_args.input_filters  # number of input channels
+        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
+        if self._block_args.expand_ratio != 1:
+            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+
+        # Depthwise convolution phase
+        k = self._block_args.kernel_size
+        s = self._block_args.stride
+        if isinstance(s, (tuple, list)) and all([s0 == s[0] for s0 in s]):
+            s = s[0]
+        self._depthwise_conv = Conv2d(
+            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
+            kernel_size=k, stride=s, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+
+        # Squeeze and Excitation layer, if desired
+        if self.has_se:
+            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+        # Output phase
+        final_oup = self._block_args.output_filters
+        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+        self._swish = MemoryEfficientSwish()
+
+    def forward(self, inputs, drop_connect_rate=None):
+        """
+        :param inputs: input tensor
+        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
+        :return: output of block
+        """
+
+        # Expansion and Depthwise Convolution
+        x = inputs
+        if self._block_args.expand_ratio != 1:
+            x = self._expand_conv(inputs)
+            x = self._bn0(x)
+            x = self._swish(x)
+
+        x = self._depthwise_conv(x)
+        x = self._bn1(x)
+        x = self._swish(x)
+
+        # Squeeze and Excitation
+        if self.has_se:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._se_reduce(x_squeezed)
+            x_squeezed = self._swish(x_squeezed)
+            x_squeezed = self._se_expand(x_squeezed)
+            x = torch.sigmoid(x_squeezed) * x
+
+        x = self._project_conv(x)
+        x = self._bn2(x)
+
+        # Skip connection and drop connect
+        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+            if drop_connect_rate:
+                x = drop_connect(x, p=drop_connect_rate, training=self.training)
+            x = x + inputs  # skip connection
+        return x
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export)"""
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+class BlockDecoder(object):
+    """ Block Decoder for readability, straight from the official TensorFlow repository """
+
+    @staticmethod
+    def _decode_block_string(block_string):
+        """ Gets a block through a string notation of arguments. """
+        assert isinstance(block_string, str)
+
+        ops = block_string.split('_')
+        options = {}
+        for op in ops:
+            splits = re.split(r'(\d.*)', op)
+            if len(splits) >= 2:
+                key, value = splits[:2]
+                options[key] = value
+
+        # Check stride
+        assert (('s' in options and len(options['s']) == 1) or
+                (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
+
+        return BlockArgs(
+            kernel_size=int(options['k']),
+            num_repeat=int(options['r']),
+            input_filters=int(options['i']),
+            output_filters=int(options['o']),
+            expand_ratio=int(options['e']),
+            id_skip=('noskip' not in block_string),
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=[int(options['s'][0])])
+
+    @staticmethod
+    def _encode_block_string(block):
+        """Encodes a block to a string."""
+        args = [
+            'r%d' % block.num_repeat,
+            'k%d' % block.kernel_size,
+            's%d%d' % (block.strides[0], block.strides[1]),
+            'e%s' % block.expand_ratio,
+            'i%d' % block.input_filters,
+            'o%d' % block.output_filters
+        ]
+        if 0 < block.se_ratio <= 1:
+            args.append('se%s' % block.se_ratio)
+        if block.id_skip is False:
+            args.append('noskip')
+        return '_'.join(args)
+
+    @staticmethod
+    def decode(string_list):
+        """
+        Decodes a list of string notations to specify blocks inside the network.
+
+        :param string_list: a list of strings, each string is a notation of block
+        :return: a list of BlockArgs namedtuples of block args
+        """
+        assert isinstance(string_list, list)
+        blocks_args = []
+        for block_string in string_list:
+            blocks_args.append(BlockDecoder._decode_block_string(block_string))
+        return blocks_args
+
+    @staticmethod
+    def encode(blocks_args):
+        """
+        Encodes a list of BlockArgs to a list of strings.
+
+        :param blocks_args: a list of BlockArgs namedtuples of block args
+        :return: a list of strings, each string is a notation of block
+        """
+        block_strings = []
+        for block in blocks_args:
+            block_strings.append(BlockDecoder._encode_block_string(block))
+        return block_strings
+
+def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
+                 drop_connect_rate=0.2, image_size=None, num_classes=1000):
+    """ Creates a efficientnet model. """
+
+    blocks_args = [
+        'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
+        'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
+        'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
+        'r1_k3_s11_e6_i192_o320_se0.25',
+    ]
+    blocks_args = BlockDecoder.decode(blocks_args)
+
+    global_params = GlobalParams(
+        batch_norm_momentum=0.99,
+        batch_norm_epsilon=1e-3,
+        dropout_rate=dropout_rate,
+        drop_connect_rate=drop_connect_rate,
+        # data_format='channels_last',  # removed, this is always true in PyTorch
+        num_classes=num_classes,
+        width_coefficient=width_coefficient,
+        depth_coefficient=depth_coefficient,
+        depth_divisor=8,
+        min_depth=None,
+        image_size=image_size,
+    )
+
+    return blocks_args, global_params
+
+
+def efficientnet_params(model_name):
+    """ Map EfficientNet model name to parameter coefficients. """
+    params_dict = {
+        # Coefficients:   width,depth,res,dropout
+        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
+        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
+    }
+    return params_dict[model_name]
+
+
+def get_model_params(model_name, override_params):
+    """ Get the block args and global params for a given model """
+    if model_name.startswith('efficientnet'):
+        w, d, s, p = efficientnet_params(model_name)
+        # note: all models have drop connect rate = 0.2
+        blocks_args, global_params = efficientnet(
+            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
+    else:
+        raise NotImplementedError('model name is not pre-defined: %s' % model_name)
+    if override_params:
+        # ValueError will be raised here if override_params has fields not included in global_params.
+        global_params = global_params._replace(**override_params)
+    return blocks_args, global_params
+
+url_map = {
+    'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth',
+    'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth',
+    'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth',
+    'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth',
+    'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth',
+    'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth',
+    'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth',
+    'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth',
+}
+
+url_map_advprop = {
+    'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth',
+    'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth',
+    'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth',
+    'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth',
+    'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth',
+    'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth',
+    'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth',
+    'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth',
+    'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth',
+}
+
+def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
+    """ Loads pretrained weights, and downloads if loading for the first time. """
+    # AutoAugment or Advprop (different preprocessing)
+    url_map_ = url_map_advprop if advprop else url_map
+    from torch.utils import model_zoo
+    state_dict = model_zoo.load_url(url_map_[model_name], map_location=torch.device('cpu'))
+    # state_dict = torch.load('../../weights/backbone_efficientnetb0.pth')
+    if load_fc:
+        ret = model.load_state_dict(state_dict, strict=False)
+        print(ret)
+    else:
+        state_dict.pop('_fc.weight')
+        state_dict.pop('_fc.bias')
+        res = model.load_state_dict(state_dict, strict=False)
+        assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
+    print('Loaded pretrained weights for {}'.format(model_name))
+
+class EfficientNet(nn.Module):
+    """
+    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
+
+    Args:
+        blocks_args (list): A list of BlockArgs to construct blocks
+        global_params (namedtuple): A set of GlobalParams shared between blocks
+
+    Example:
+        model = EfficientNet.from_pretrained('efficientnet-b0')
+
+    """
+
+    def __init__(self, blocks_args=None, global_params=None):
+        super().__init__()
+        assert isinstance(blocks_args, list), 'blocks_args should be a list'
+        assert len(blocks_args) > 0, 'block args must be greater than 0'
+        self._global_params = global_params
+        self._blocks_args = blocks_args
+
+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
+        # Batch norm parameters
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        # Stem
+        in_channels = 3  # rgb
+        out_channels = round_filters(32, self._global_params)  # number of output channels
+        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Build blocks
+        self._blocks = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.input_filters, self._global_params),
+                output_filters=round_filters(block_args.output_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks.append(MBConvBlock(block_args, self._global_params))
+            if block_args.num_repeat > 1:
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks.append(MBConvBlock(block_args, self._global_params))
+
+        # Head
+        in_channels = block_args.output_filters  # output of final block
+        out_channels = round_filters(1280, self._global_params)
+        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Final linear layer
+        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+        self._dropout = nn.Dropout(self._global_params.dropout_rate)
+        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export)"""
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
+    def extract_features(self, inputs):
+        """ Returns output of the final convolution layer """
+
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return x
+
+    def forward(self, inputs):
+        """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
+        bs = inputs.size(0)
+        # Convolution layers
+        x = self.extract_features(inputs)
+
+        # Pooling and final linear layer
+        x = self._avg_pooling(x)
+        x = x.view(bs, -1)
+        x = self._dropout(x)
+        x = self._fc(x)
+        return x
+
+    @classmethod
+    def from_name(cls, model_name, override_params=None):
+        cls._check_model_name_is_valid(model_name)
+        blocks_args, global_params = get_model_params(model_name, override_params)
+        return cls(blocks_args, global_params)
+
+    @classmethod
+    def from_pretrained(cls, model_name, load_weights=True, advprop=True, num_classes=1000, in_channels=3):
+        model = cls.from_name(model_name, override_params={'num_classes': num_classes})
+        if load_weights:
+            load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
+        if in_channels != 3:
+            Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
+            out_channels = round_filters(32, model._global_params)
+            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        return model
+
+    @classmethod
+    def get_image_size(cls, model_name):
+        cls._check_model_name_is_valid(model_name)
+        _, _, res, _ = efficientnet_params(model_name)
+        return res
+
+    @classmethod
+    def _check_model_name_is_valid(cls, model_name):
+        """ Validates model name. """
+        valid_models = ['efficientnet-b'+str(i) for i in range(9)]
+        if model_name not in valid_models:
+            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
+
+class EfficientNetD(nn.Module):
+    """
+    modified by Zylo117
+    """
+
+    def __init__(self, compound_coef, load_weights=False):
+        super().__init__()
+        model = EfficientNet.from_pretrained(f'efficientnet-b{compound_coef}', load_weights)
+        del model._conv_head
+        del model._bn1
+        del model._avg_pooling
+        del model._dropout
+        del model._fc
+        self.model = model
+
+    def forward(self, x):
+        x = self.model._conv_stem(x)
+        x = self.model._bn0(x)
+        x = self.model._swish(x)
+        feature_maps = []
+
+        # TODO: temporarily storing extra tensor last_x and del it later might not be a good idea,
+        #  try recording stride changing when creating efficientnet,
+        #  and then apply it here.
+        last_x = None
+        for idx, block in enumerate(self.model._blocks):
+            drop_connect_rate = self.model._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self.model._blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+
+            if tuple(block._depthwise_conv.stride) == (2, 2):
+                feature_maps.append(last_x)
+            elif idx == len(self.model._blocks) - 1:
+                feature_maps.append(x)
+            last_x = x
+        del last_x
+        return feature_maps[1:]
+
+class Anchors(nn.Module):
+    """
+    adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117
+    """
+
+    def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs):
+        super().__init__()
+        from qd.qd_common import print_frame_info
+        print_frame_info()
+        self.anchor_scale = anchor_scale
+
+        if pyramid_levels is None:
+            self.pyramid_levels = [3, 4, 5, 6, 7]
+
+        self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels])
+        self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
+        self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
+
+        self.buffer = {}
+
+    @torch.no_grad()
+    def forward(self, image, dtype=torch.float32, features=None):
+        """Generates multiscale anchor boxes.
+
+        Args:
+          image_size: integer number of input image size. The input image has the
+            same dimension for width and height. The image_size should be divided by
+            the largest feature stride 2^max_level.
+          anchor_scale: float number representing the scale of size of the base
+            anchor to the feature stride 2^level.
+          anchor_configs: a dictionary with keys as the levels of anchors and
+            values as a list of anchor configuration.
+
+        Returns:
+          anchor_boxes: a numpy array with shape [N, 4], which stacks anchors on all
+            feature levels.
+        Raises:
+          ValueError: input size must be the multiple of largest feature stride.
+        """
+        image_shape = image.shape[2:]
+        anchor_key = self.get_key('anchor', image_shape)
+        stride_idx_key = self.get_key('anchor_stride_index', image_shape)
+
+        if anchor_key in self.buffer:
+            return {'stride_idx': self.buffer[stride_idx_key].detach(),
+                    'anchor': self.buffer[anchor_key].detach()}
+
+        if dtype == torch.float16:
+            dtype = np.float16
+        else:
+            dtype = np.float32
+
+        boxes_all = []
+        all_idx_strides = []
+        for idx_stride, stride in enumerate(self.strides):
+            boxes_level = []
+            for scale, ratio in itertools.product(self.scales, self.ratios):
+                if features is not None:
+                    f_h, f_w = features[idx_stride].shape[-2:]
+                    x = np.arange(stride / 2, stride * f_w, stride)
+                    y = np.arange(stride / 2, stride * f_h, stride)
+                else:
+                    if image_shape[1] % stride != 0:
+                        x_max = stride * ((image_shape[1] + stride - 1) // stride)
+                        y_max = stride * ((image_shape[0] + stride - 1) // stride)
+                    else:
+                        x_max = image_shape[1]
+                        y_max = image_shape[0]
+                    x = np.arange(stride / 2, x_max, stride)
+                    y = np.arange(stride / 2, y_max, stride)
+                xv, yv = np.meshgrid(x, y)
+                xv = xv.reshape(-1)
+                yv = yv.reshape(-1)
+
+                base_anchor_size = self.anchor_scale * stride * scale
+                anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0
+                anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0
+                # y1,x1,y2,x2
+                boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
+                                   yv + anchor_size_y_2, xv + anchor_size_x_2))
+                boxes = np.swapaxes(boxes, 0, 1)
+                boxes_level.append(np.expand_dims(boxes, axis=1))
+            # concat anchors on the same level to the reshape NxAx4
+            boxes_level = np.concatenate(boxes_level, axis=1)
+            boxes_level = boxes_level.reshape([-1, 4])
+            idx_strides = torch.tensor([idx_stride] * len(boxes_level))
+            all_idx_strides.append(idx_strides)
+            boxes_all.append(boxes_level)
+
+        anchor_boxes = np.vstack(boxes_all)
+        anchor_stride_indices = torch.cat(all_idx_strides).to(image.device)
+
+        self.buffer[stride_idx_key] = anchor_stride_indices
+
+        anchor_boxes = torch.from_numpy(anchor_boxes.astype(dtype)).to(image.device)
+        anchor_boxes = anchor_boxes.unsqueeze(0)
+
+        # save it for later use to reduce overhead
+        self.buffer[anchor_key] = anchor_boxes
+
+        return {'stride_idx': self.buffer[stride_idx_key],
+                'anchor': self.buffer[anchor_key]}
+
+    def get_key(self, hint, image_shape):
+        return '{}_{}'.format(hint, '_'.join(map(str, image_shape)))
+
+class EffNetFPN(nn.Module):
+    def __init__(self, compound_coef=0, start_from=3):
+        super().__init__()
+
+        self.backbone_net = EfficientNetD(EfficientDetBackbone.backbone_compound_coef[compound_coef],
+                                          load_weights=False)
+        if start_from == 3:
+            conv_channel_coef = EfficientDetBackbone.conv_channel_coef[compound_coef]
+        else:
+            conv_channel_coef = EfficientDetBackbone.conv_channel_coef2345[compound_coef]
+        self.bifpn = nn.Sequential(
+            *[BiFPN(EfficientDetBackbone.fpn_num_filters[compound_coef],
+                    conv_channel_coef,
+                    True if _ == 0 else False,
+                    attention=True if compound_coef < 6 else False,
+                    adaptive_up=True)
+              for _ in range(EfficientDetBackbone.fpn_cell_repeats[compound_coef])])
+
+        self.out_channels = EfficientDetBackbone.fpn_num_filters[compound_coef]
+
+        self.start_from = start_from
+        assert self.start_from in [2, 3]
+
+    def forward(self, inputs):
+        if self.start_from == 3:
+            _, p3, p4, p5 = self.backbone_net(inputs)
+
+            features = (p3, p4, p5)
+            features = self.bifpn(features)
+            return features
+        else:
+            p2, p3, p4, p5 = self.backbone_net(inputs)
+            features = (p2, p3, p4, p5)
+            features = self.bifpn(features)
+            return features
+
+class EfficientDetBackbone(nn.Module):
+    backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6]
+    fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
+    conv_channel_coef = {
+        # the channels of P3/P4/P5.
+        0: [40, 112, 320],
+        1: [40, 112, 320],
+        2: [48, 120, 352],
+        3: [48, 136, 384],
+        4: [56, 160, 448],
+        5: [64, 176, 512],
+        6: [72, 200, 576],
+        7: [72, 200, 576],
+    }
+    conv_channel_coef2345 = {
+        # the channels of P2/P3/P4/P5.
+        0: [24, 40, 112, 320],
+        # to be determined for the following
+        1: [24, 40, 112, 320],
+        2: [24, 48, 120, 352],
+        3: [32, 48, 136, 384],
+        4: [32, 56, 160, 448],
+        5: [40, 64, 176, 512],
+        6: [72, 200],
+        7: [72, 200],
+    }
+    fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
+    def __init__(self, num_classes=80, compound_coef=0, load_weights=False,
+                 prior_prob=0.01, **kwargs):
+        super(EfficientDetBackbone, self).__init__()
+        self.compound_coef = compound_coef
+
+        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
+        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
+        self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.]
+        self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
+        self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
+
+        num_anchors = len(self.aspect_ratios) * self.num_scales
+
+        self.bifpn = nn.Sequential(
+            *[BiFPN(self.fpn_num_filters[self.compound_coef],
+                    self.conv_channel_coef[compound_coef],
+                    True if _ == 0 else False,
+                    attention=True if compound_coef < 6 else False,
+                    adaptive_up=kwargs.get('adaptive_up'))
+              for _ in range(self.fpn_cell_repeats[compound_coef])])
+
+        self.num_classes = num_classes
+        self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
+                                   num_layers=self.box_class_repeats[self.compound_coef])
+        self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
+                                     num_classes=num_classes,
+                                     num_layers=self.box_class_repeats[self.compound_coef],
+                                     prior_prob=prior_prob)
+        anchor_scale = self.anchor_scale[compound_coef]
+        if kwargs.get('anchor_scale'):
+            anchor_scale = kwargs.pop('anchor_scale')
+        if 'anchor_scale' in kwargs:
+            del kwargs['anchor_scale']
+        self.anchors = Anchors(anchor_scale=anchor_scale, **kwargs)
+
+        self.backbone_net = EfficientNetD(self.backbone_compound_coef[compound_coef], load_weights)
+
+    def freeze_bn(self):
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
+    def forward(self, inputs):
+        _, p3, p4, p5 = self.backbone_net(inputs)
+
+        features = (p3, p4, p5)
+        features = self.bifpn(features)
+
+        regression = self.regressor(features)
+        classification = self.classifier(features)
+        anchors = self.anchors(inputs, inputs.dtype, features=features)
+
+        return features, regression, classification, anchors
+
+    def init_backbone(self, path):
+        state_dict = torch.load(path)
+        try:
+            ret = self.load_state_dict(state_dict, strict=False)
+            print(ret)
+        except RuntimeError as e:
+            print('Ignoring ' + str(e) + '"')
+
+def init_weights(model):
+    for name, module in model.named_modules():
+        is_conv_layer = isinstance(module, nn.Conv2d)
+
+        if is_conv_layer:
+            nn.init.kaiming_uniform_(module.weight.data)
+
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+def calc_iou(a, b):
+    # a(anchor) [boxes, (y1, x1, y2, x2)]
+    # b(gt, coco-style) [boxes, (x1, y1, x2, y2)]
+
+    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
+    iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
+    ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
+    iw = torch.clamp(iw, min=0)
+    ih = torch.clamp(ih, min=0)
+    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
+    ua = torch.clamp(ua, min=1e-8)
+    intersection = iw * ih
+    IoU = intersection / ua
+
+    return IoU
+
+class BBoxTransform(nn.Module):
+    def forward(self, anchors, regression):
+        """
+        decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py
+
+        Args:
+            anchors: [batchsize, boxes, (y1, x1, y2, x2)]
+            regression: [batchsize, boxes, (dy, dx, dh, dw)]
+
+        Returns:
+
+        """
+        y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2
+        x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2
+        ha = anchors[..., 2] - anchors[..., 0]
+        wa = anchors[..., 3] - anchors[..., 1]
+
+        w = regression[..., 3].exp() * wa
+        h = regression[..., 2].exp() * ha
+
+        y_centers = regression[..., 0] * ha + y_centers_a
+        x_centers = regression[..., 1] * wa + x_centers_a
+
+        ymin = y_centers - h / 2.
+        xmin = x_centers - w / 2.
+        ymax = y_centers + h / 2.
+        xmax = x_centers + w / 2.
+        if len(anchors.shape) == 3:
+            return torch.stack([xmin, ymin, xmax, ymax], dim=2)
+        else:
+            return torch.stack([xmin, ymin, xmax, ymax], dim=1)
+
+
+class ClipBoxes(nn.Module):
+
+    def __init__(self):
+        super(ClipBoxes, self).__init__()
+
+    def forward(self, boxes, img):
+        batch_size, num_channels, height, width = img.shape
+
+        boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
+        boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)
+
+        boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1)
+        boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1)
+
+        return boxes
+
+def postprocess2(x, anchors, regression, classification,
+                 transformed_anchors, threshold, iou_threshold, max_box):
+    anchors = anchors['anchor']
+    all_above_th = classification > threshold
+    out = []
+    num_image = x.shape[0]
+    num_class = classification.shape[-1]
+
+    #classification = classification.cpu()
+    #transformed_anchors = transformed_anchors.cpu()
+    #all_above_th = all_above_th.cpu()
+    max_box_pre_nms = 1000
+    for i in range(num_image):
+        all_rois = []
+        all_class_ids = []
+        all_scores = []
+        for c in range(num_class):
+            above_th = all_above_th[i, :, c].nonzero()
+            if len(above_th) == 0:
+                continue
+            above_prob = classification[i, above_th, c].squeeze(1)
+            if len(above_th) > max_box_pre_nms:
+                _, idx = above_prob.topk(max_box_pre_nms)
+                above_th = above_th[idx]
+                above_prob = above_prob[idx]
+            transformed_anchors_per = transformed_anchors[i,above_th,:].squeeze(dim=1)
+            from torchvision.ops import nms
+            nms_idx = nms(transformed_anchors_per, above_prob, iou_threshold=iou_threshold)
+            if len(nms_idx) > 0:
+                all_rois.append(transformed_anchors_per[nms_idx])
+                ids = torch.tensor([c] * len(nms_idx))
+                all_class_ids.append(ids)
+                all_scores.append(above_prob[nms_idx])
+
+        if len(all_rois) > 0:
+            rois = torch.cat(all_rois)
+            class_ids = torch.cat(all_class_ids)
+            scores = torch.cat(all_scores)
+            if len(scores) > max_box:
+                _, idx = torch.topk(scores, max_box)
+                rois = rois[idx, :]
+                class_ids = class_ids[idx]
+                scores = scores[idx]
+            out.append({
+                'rois': rois,
+                'class_ids': class_ids,
+                'scores': scores,
+            })
+        else:
+            out.append({
+                'rois': [],
+                'class_ids': [],
+                'scores': [],
+            })
+
+    return out
+
+def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold):
+    anchors = anchors['anchor']
+    transformed_anchors = regressBoxes(anchors, regression)
+    transformed_anchors = clipBoxes(transformed_anchors, x)
+    scores = torch.max(classification, dim=2, keepdim=True)[0]
+    scores_over_thresh = (scores > threshold)[:, :, 0]
+    out = []
+    for i in range(x.shape[0]):
+        if scores_over_thresh.sum() == 0:
+            out.append({
+                'rois': [],
+                'class_ids': [],
+                'scores': [],
+            })
+            continue
+
+        classification_per = classification[i, scores_over_thresh[i, :], ...].permute(1, 0)
+        transformed_anchors_per = transformed_anchors[i, scores_over_thresh[i, :], ...]
+        scores_per = scores[i, scores_over_thresh[i, :], ...]
+        from torchvision.ops import nms
+        anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold)
+
+        if anchors_nms_idx.shape[0] != 0:
+            scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0)
+            boxes_ = transformed_anchors_per[anchors_nms_idx, :]
+
+            out.append({
+                'rois': boxes_,
+                'class_ids': classes_,
+                'scores': scores_,
+            })
+        else:
+            out.append({
+                'rois': [],
+                'class_ids': [],
+                'scores': [],
+            })
+
+    return out
+
+def display(preds, imgs, obj_list, imshow=True, imwrite=False):
+    for i in range(len(imgs)):
+        if len(preds[i]['rois']) == 0:
+            continue
+
+        for j in range(len(preds[i]['rois'])):
+            (x1, y1, x2, y2) = preds[i]['rois'][j].detach().cpu().numpy().astype(np.int)
+            logging.info((x1, y1, x2, y2))
+            cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
+            #obj = obj_list[preds[i]['class_ids'][j]]
+            #score = float(preds[i]['scores'][j])
+
+            #cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
+                        #(x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                        #(255, 255, 0), 1)
+            #break
+        if imshow:
+            cv2.imshow('image', imgs[i])
+            cv2.waitKey(0)
+
+def calculate_focal_loss2(classification, target_list, alpha, gamma):
+    from maskrcnn_benchmark.layers.sigmoid_focal_loss import sigmoid_focal_loss_cuda
+    cls_loss = sigmoid_focal_loss_cuda(classification, target_list.int(), gamma, alpha)
+    return cls_loss
+
+def calculate_focal_loss(classification, targets, alpha, gamma):
+    classification = classification.sigmoid()
+    device = classification.device
+    alpha_factor = torch.ones_like(targets) * alpha
+    alpha_factor = alpha_factor.to(device)
+
+    alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
+    focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
+    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
+
+    bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
+
+    cls_loss = focal_weight * bce
+
+    zeros = torch.zeros_like(cls_loss)
+    zeros = zeros.to(device)
+    cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
+    return cls_loss.mean()
+
+def calculate_giou(pred, gt):
+    ax1, ay1, ax2, ay2 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
+    bx1, by1, bx2, by2 = gt[:, 0], gt[:, 1], gt[:, 2], gt[:, 3]
+    a = (ax2 - ax1) * (ay2 - ay1)
+    b = (bx2 - bx1) * (by2 - by1)
+    max_x1, _ = torch.max(torch.stack([ax1, bx1], dim=1), dim=1)
+    max_y1, _ = torch.max(torch.stack([ay1, by1], dim=1), dim=1)
+    min_x2, _ = torch.min(torch.stack([ax2, bx2], dim=1), dim=1)
+    min_y2, _ = torch.min(torch.stack([ay2, by2], dim=1), dim=1)
+    inter = (min_x2 > max_x1) * (min_y2 > max_y1)
+    inter = inter * (min_x2 - max_x1) * (min_y2 - max_y1)
+
+    min_x1, _ = torch.min(torch.stack([ax1, bx1], dim=1), dim=1)
+    min_y1, _ = torch.min(torch.stack([ay1, by1], dim=1), dim=1)
+    max_x2, _ = torch.max(torch.stack([ax2, bx2], dim=1), dim=1)
+    max_y2, _ = torch.max(torch.stack([ay2, by2], dim=1), dim=1)
+    cover = (max_x2 - min_x1) * (max_y2 - min_y1)
+    union = a + b - inter
+    iou = inter / (union + 1e-5)
+    giou = iou - (cover - union) / (cover + 1e-5)
+    return giou
+
+class FocalLoss(nn.Module):
+    def __init__(self, alpha=0.25, gamma=2., cls_loss_type='FL', smooth_bce_pos=0.99,
+                 smooth_bce_neg=0.01,
+                 reg_loss_type='L1',
+                 at_least_1_assgin=False,
+                 neg_iou_th=0.4,
+                 pos_iou_th=0.5,
+                 cls_weight=1.,
+                 reg_weight=1.,
+                 ):
+        super(FocalLoss, self).__init__()
+        from qd.qd_common import print_frame_info
+        print_frame_info()
+        self.iter = 0
+        self.reg_loss_type = reg_loss_type
+        self.regressBoxes = BBoxTransform()
+        if cls_loss_type == 'FL':
+            from qd.layers.loss import FocalLossWithLogitsNegLoss
+            self.cls_loss = FocalLossWithLogitsNegLoss(alpha, gamma)
+        elif cls_loss_type == 'BCE':
+            from qd.qd_pytorch import BCEWithLogitsNegLoss
+            self.cls_loss = BCEWithLogitsNegLoss(reduction='sum')
+        elif cls_loss_type == 'SmoothBCE':
+            from qd.layers.loss import SmoothBCEWithLogitsNegLoss
+            self.cls_loss = SmoothBCEWithLogitsNegLoss(
+                pos=smooth_bce_pos, neg=smooth_bce_neg)
+        elif cls_loss_type == 'SmoothFL':
+            from qd.layers.loss import FocalSmoothBCEWithLogitsNegLoss
+            self.cls_loss = FocalSmoothBCEWithLogitsNegLoss(
+                alpha=alpha, gamma=2.,
+                pos=smooth_bce_pos, neg=smooth_bce_neg)
+        else:
+            raise NotImplementedError(cls_loss_type)
+        self.at_least_1_assgin = at_least_1_assgin
+
+        self.gt_total = 0
+        self.gt_saved_by_at_least = 0
+
+        self.neg_iou_th = neg_iou_th
+        self.pos_iou_th = pos_iou_th
+
+        self.cls_weight = cls_weight
+        self.reg_weight = reg_weight
+
+        self.buf = {}
+
+    def forward(self, classifications, regressions, anchor_info, annotations, **kwargs):
+        debug = (self.iter % 100) == 0
+        self.iter += 1
+        if debug:
+            from collections import defaultdict
+            debug_info = defaultdict(list)
+
+        batch_size = classifications.shape[0]
+        classification_losses = []
+        regression_losses = []
+        anchors = anchor_info['anchor']
+        anchor = anchors[0, :, :]  # assuming all image sizes are the same, which it is
+        dtype = anchors.dtype
+
+        anchor_widths = anchor[:, 3] - anchor[:, 1]
+        anchor_heights = anchor[:, 2] - anchor[:, 0]
+        anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
+        anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights
+
+        #anchor_widths = anchor[:, 2] - anchor[:, 0]
+        #anchor_heights = anchor[:, 3] - anchor[:, 1]
+        #anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
+        #anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights
+        device = classifications.device
+
+        for j in range(batch_size):
+
+            classification = classifications[j, :, :]
+            regression = regressions[j, :, :]
+
+            bbox_annotation = annotations[j]
+            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
+
+            #classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
+
+            if bbox_annotation.shape[0] == 0:
+                #cls_loss = calculate_focal_loss2(classification,
+                                                 #torch.zeros(len(classification)), alpha,
+                                                #gamma)
+                #cls_loss = cls_loss.mean()
+                cls_loss = torch.tensor(0).to(dtype).to(device)
+                regression_losses.append(torch.tensor(0).to(dtype).to(device))
+                classification_losses.append(cls_loss)
+                continue
+
+            IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
+
+            IoU_max, IoU_argmax = torch.max(IoU, dim=1)
+            if self.at_least_1_assgin:
+                iou_max_gt, iou_argmax_gt = torch.max(IoU, dim=0)
+                curr_saved = (iou_max_gt < self.pos_iou_th).sum()
+                self.gt_saved_by_at_least += curr_saved
+                self.gt_total += len(iou_argmax_gt)
+                IoU_max[iou_argmax_gt] = 1.
+                IoU_argmax[iou_argmax_gt] = torch.arange(len(iou_argmax_gt)).to(device)
+
+            # compute the loss for classification
+            targets = torch.ones_like(classification) * -1
+            targets = targets.to(device)
+
+            targets[torch.lt(IoU_max, self.neg_iou_th), :] = 0
+
+            positive_indices = torch.ge(IoU_max, self.pos_iou_th)
+
+            num_positive_anchors = positive_indices.sum()
+
+            assigned_annotations = bbox_annotation[IoU_argmax, :]
+
+            targets[positive_indices, :] = 0
+            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
+
+            if debug:
+                if num_positive_anchors > 0:
+                    debug_info['pos_conf'].append(classification[
+                        positive_indices,
+                        assigned_annotations[positive_indices, 4].long()].mean())
+                debug_info['neg_conf'].append(classification[targets == 0].mean())
+                stride_idx = anchor_info['stride_idx']
+                positive_stride_idx = stride_idx[positive_indices]
+                pos_count_each_stride = torch.tensor(
+                    [(positive_stride_idx == i).sum() for i in range(5)])
+                if 'cum_pos_count_each_stride' not in self.buf:
+                    self.buf['cum_pos_count_each_stride'] = pos_count_each_stride
+                else:
+                    cum_pos_count_each_stride = self.buf['cum_pos_count_each_stride']
+                    cum_pos_count_each_stride += pos_count_each_stride
+                    self.buf['cum_pos_count_each_stride'] = cum_pos_count_each_stride
+
+            #cls_loss = calculate_focal_loss(classification, targets, alpha,
+                                            #gamma)
+            cls_loss = self.cls_loss(classification, targets)
+
+            cls_loss = cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0)
+            assert cls_loss == cls_loss
+            classification_losses.append(cls_loss)
+
+            if positive_indices.sum() > 0:
+                assigned_annotations = assigned_annotations[positive_indices, :]
+                if self.reg_loss_type == 'L1':
+                    anchor_widths_pi = anchor_widths[positive_indices]
+                    anchor_heights_pi = anchor_heights[positive_indices]
+                    anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
+                    anchor_ctr_y_pi = anchor_ctr_y[positive_indices]
+
+                    gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
+                    gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
+                    gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
+                    gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights
+
+                    # efficientdet style
+                    gt_widths = torch.clamp(gt_widths, min=1)
+                    gt_heights = torch.clamp(gt_heights, min=1)
+
+                    targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
+                    targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
+                    targets_dw = torch.log(gt_widths / anchor_widths_pi)
+                    targets_dh = torch.log(gt_heights / anchor_heights_pi)
+
+                    targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw))
+                    targets = targets.t()
+
+                    regression_diff = torch.abs(targets - regression[positive_indices, :])
+
+                    regression_loss = torch.where(
+                        torch.le(regression_diff, 1.0 / 9.0),
+                        0.5 * 9.0 * torch.pow(regression_diff, 2),
+                        regression_diff - 0.5 / 9.0
+                    ).mean()
+                elif self.reg_loss_type == 'GIOU':
+                    curr_regression = regression[positive_indices, :]
+                    curr_anchors = anchor[positive_indices]
+                    curr_pred_xyxy = self.regressBoxes(curr_anchors,
+                                                        curr_regression)
+                    regression_loss = 1.- calculate_giou(curr_pred_xyxy, assigned_annotations)
+                    regression_loss = regression_loss.mean()
+                    assert regression_loss == regression_loss
+                else:
+                    raise NotImplementedError
+                regression_losses.append(regression_loss)
+            else:
+                if torch.cuda.is_available():
+                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
+                else:
+                    regression_losses.append(torch.tensor(0).to(dtype))
+        if debug:
+            if len(debug_info) > 0:
+                logging.info('pos = {}; neg = {}, saved_ratio = {}/{}={:.1f}, '
+                             'stride_info = {}'
+                             .format(
+                                 torch.tensor(debug_info['pos_conf']).mean(),
+                                 torch.tensor(debug_info['neg_conf']).mean(),
+                                 self.gt_saved_by_at_least,
+                                 self.gt_total,
+                                 1. * self.gt_saved_by_at_least / self.gt_total,
+                                 self.buf['cum_pos_count_each_stride'],
+                             ))
+        return self.cls_weight * torch.stack(classification_losses).mean(dim=0, keepdim=True), \
+               self.reg_weight * torch.stack(regression_losses).mean(dim=0, keepdim=True)
+
+class ModelWithLoss(nn.Module):
+    def __init__(self, model, criterion):
+        super().__init__()
+        self.criterion = criterion
+        self.module = model
+
+    def forward(self, *args):
+        if len(args) == 2:
+            imgs, annotations = args
+        elif len(args) == 1:
+            imgs, annotations = args[0][:2]
+        _, regression, classification, anchors = self.module(imgs)
+        cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations)
+        return {'cls_loss': cls_loss, 'reg_loss': reg_loss}
+
+class TorchVisionNMS(nn.Module):
+    def __init__(self, iou_threshold):
+        super().__init__()
+        self.iou_threshold = iou_threshold
+
+    def forward(self, box, prob):
+        nms_idx = nms(box, prob, iou_threshold=self.iou_threshold)
+        return nms_idx
+
+class PostProcess(nn.Module):
+    def __init__(self, iou_threshold):
+        super().__init__()
+        self.nms = TorchVisionNMS(iou_threshold)
+
+    def forward(self, x, anchors, regression,
+                classification,
+                transformed_anchors, threshold, max_box):
+        all_above_th = classification > threshold
+        out = []
+        num_image = x.shape[0]
+        num_class = classification.shape[-1]
+
+        #classification = classification.cpu()
+        #transformed_anchors = transformed_anchors.cpu()
+        #all_above_th = all_above_th.cpu()
+        max_box_pre_nms = 1000
+        for i in range(num_image):
+            all_rois = []
+            all_class_ids = []
+            all_scores = []
+            for c in range(num_class):
+                above_th = all_above_th[i, :, c].nonzero()
+                if len(above_th) == 0:
+                    continue
+                above_prob = classification[i, above_th, c].squeeze(1)
+                if len(above_th) > max_box_pre_nms:
+                    _, idx = above_prob.topk(max_box_pre_nms)
+                    above_th = above_th[idx]
+                    above_prob = above_prob[idx]
+                transformed_anchors_per = transformed_anchors[i,above_th,:].squeeze(dim=1)
+                nms_idx = self.nms(transformed_anchors_per, above_prob)
+                if len(nms_idx) > 0:
+                    all_rois.append(transformed_anchors_per[nms_idx])
+                    ids = torch.tensor([c] * len(nms_idx))
+                    all_class_ids.append(ids)
+                    all_scores.append(above_prob[nms_idx])
+
+            if len(all_rois) > 0:
+                rois = torch.cat(all_rois)
+                class_ids = torch.cat(all_class_ids)
+                scores = torch.cat(all_scores)
+                if len(scores) > max_box:
+                    _, idx = torch.topk(scores, max_box)
+                    rois = rois[idx, :]
+                    class_ids = class_ids[idx]
+                    scores = scores[idx]
+                out.append({
+                    'rois': rois,
+                    'class_ids': class_ids,
+                    'scores': scores,
+                })
+            else:
+                out.append({
+                    'rois': [],
+                    'class_ids': [],
+                    'scores': [],
+                })
+
+        return out
+
+class InferenceModel(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.module = model
+
+        self.regressBoxes = BBoxTransform()
+        self.clipBoxes = ClipBoxes()
+        self.threshold = 0.01
+        self.nms_threshold = 0.5
+        self.max_box = 100
+        self.debug = False
+        self.post_process = PostProcess(self.nms_threshold)
+
+    def forward(self, sample):
+        features, regression, classification, anchor_info = self.module(sample['image'])
+        anchors = anchor_info['anchor']
+        classification = classification.sigmoid()
+        transformed_anchors = self.regressBoxes(anchors, regression)
+        transformed_anchors = self.clipBoxes(transformed_anchors, sample['image'])
+
+        preds = self.post_process(sample['image'], anchors, regression,
+                            classification, transformed_anchors,
+                            self.threshold, self.max_box)
+
+        if self.debug:
+            logging.info('debugging')
+            imgs = sample['image']
+            imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
+            imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
+            imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
+            display(preds, imgs, list(map(str, range(80))))
+
+        for p, s in zip(preds, sample['scale']):
+            if len(p['rois']) > 0:
+                p['rois'] /= s
+        return preds
+
diff --git a/maskrcnn_benchmark/modeling/backbone/efficientnet.py b/maskrcnn_benchmark/modeling/backbone/efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8a124b0e0672c723f44d75b63d2434fcfe0f52c
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/efficientnet.py
@@ -0,0 +1,691 @@
+"""
+    EfficientNet for ImageNet-1K, implemented in PyTorch.
+    Original papers:
+    - 'EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,' https://arxiv.org/abs/1905.11946,
+    - 'Adversarial Examples Improve Image Recognition,' https://arxiv.org/abs/1911.09665.
+"""
+
+import os
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from maskrcnn_benchmark.layers import SEBlock, swish
+
+
+def round_channels(channels,
+                   divisor=8):
+    """
+    Round weighted channel number (make divisible operation).
+
+    Parameters:
+    ----------
+    channels : int or float
+        Original number of channels.
+    divisor : int, default 8
+        Alignment value.
+
+    Returns
+    -------
+    int
+        Weighted number of channels.
+    """
+    rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor)
+    if float(rounded_channels) < 0.9 * channels:
+        rounded_channels += divisor
+    return rounded_channels
+
+
+def calc_tf_padding(x,
+                    kernel_size,
+                    stride=1,
+                    dilation=1):
+    """
+    Calculate TF-same like padding size.
+
+    Parameters:
+    ----------
+    x : tensor
+        Input tensor.
+    kernel_size : int
+        Convolution window size.
+    stride : int, default 1
+        Strides of the convolution.
+    dilation : int, default 1
+        Dilation value for convolution layer.
+
+    Returns
+    -------
+    tuple of 4 int
+        The size of the padding.
+    """
+    height, width = x.size()[2:]
+    oh = math.ceil(height / stride)
+    ow = math.ceil(width / stride)
+    pad_h = max((oh - 1) * stride + (kernel_size - 1) * dilation + 1 - height, 0)
+    pad_w = max((ow - 1) * stride + (kernel_size - 1) * dilation + 1 - width, 0)
+    return pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2
+
+
+class ConvBlock(nn.Module):
+    """
+    Standard convolution block with Batch normalization and activation.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : int or tuple/list of 2 int
+        Convolution window size.
+    stride : int or tuple/list of 2 int
+        Strides of the convolution.
+    padding : int, or tuple/list of 2 int, or tuple/list of 4 int
+        Padding value for convolution layer.
+    dilation : int or tuple/list of 2 int, default 1
+        Dilation value for convolution layer.
+    groups : int, default 1
+        Number of groups.
+    bias : bool, default False
+        Whether the layer uses a bias vector.
+    use_bn : bool, default True
+        Whether to use BatchNorm layer.
+    bn_eps : float, default 1e-5
+        Small float added to variance in Batch norm.
+    activation : function or str or None, default nn.ReLU(inplace=True)
+        Activation function or name of activation function.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 dilation=1,
+                 groups=1,
+                 bias=False,
+                 use_bn=True,
+                 bn_eps=1e-5,
+                 activation=nn.ReLU(inplace=True)):
+        super(ConvBlock, self).__init__()
+        self.activate = (activation is not None)
+        self.use_bn = use_bn
+        self.use_pad = (isinstance(padding, (list, tuple)) and (len(padding) == 4))
+
+        if self.use_pad:
+            self.pad = nn.ZeroPad2d(padding=padding)
+            padding = 0
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        if self.use_bn:
+            self.bn = nn.BatchNorm2d(
+                num_features=out_channels,
+                eps=bn_eps)
+        if self.activate:
+            self.activ = activation
+
+    def forward(self, x):
+        if self.use_pad:
+            x = self.pad(x)
+        x = self.conv(x)
+        if self.use_bn:
+            x = self.bn(x)
+        if self.activate:
+            x = self.activ(x)
+        return x
+
+
+def conv1x1_block(in_channels,
+                  out_channels,
+                  stride=1,
+                  padding=0,
+                  groups=1,
+                  bias=False,
+                  use_bn=True,
+                  bn_eps=1e-5,
+                  activation=nn.ReLU(inplace=True)):
+    """
+    1x1 version of the standard convolution block.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    stride : int or tuple/list of 2 int, default 1
+        Strides of the convolution.
+    padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 0
+        Padding value for convolution layer.
+    groups : int, default 1
+        Number of groups.
+    bias : bool, default False
+        Whether the layer uses a bias vector.
+    use_bn : bool, default True
+        Whether to use BatchNorm layer.
+    bn_eps : float, default 1e-5
+        Small float added to variance in Batch norm.
+    activation : function or str or None, default nn.ReLU(inplace=True)
+        Activation function or name of activation function.
+    """
+    return ConvBlock(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=1,
+        stride=stride,
+        padding=padding,
+        groups=groups,
+        bias=bias,
+        use_bn=use_bn,
+        bn_eps=bn_eps,
+        activation=activation)
+
+
+def conv3x3_block(in_channels,
+                  out_channels,
+                  stride=1,
+                  padding=1,
+                  dilation=1,
+                  groups=1,
+                  bias=False,
+                  use_bn=True,
+                  bn_eps=1e-5,
+                  activation=nn.ReLU(inplace=True)):
+    """
+    3x3 version of the standard convolution block.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    stride : int or tuple/list of 2 int, default 1
+        Strides of the convolution.
+    padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1
+        Padding value for convolution layer.
+    dilation : int or tuple/list of 2 int, default 1
+        Dilation value for convolution layer.
+    groups : int, default 1
+        Number of groups.
+    bias : bool, default False
+        Whether the layer uses a bias vector.
+    use_bn : bool, default True
+        Whether to use BatchNorm layer.
+    bn_eps : float, default 1e-5
+        Small float added to variance in Batch norm.
+    activation : function or str or None, default nn.ReLU(inplace=True)
+        Activation function or name of activation function.
+    """
+    return ConvBlock(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=3,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+        bias=bias,
+        use_bn=use_bn,
+        bn_eps=bn_eps,
+        activation=activation)
+
+
+def dwconv3x3_block(in_channels,
+                    out_channels,
+                    stride=1,
+                    padding=1,
+                    dilation=1,
+                    bias=False,
+                    bn_eps=1e-5,
+                    activation=nn.ReLU(inplace=True)):
+    """
+    3x3 depthwise version of the standard convolution block.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    stride : int or tuple/list of 2 int, default 1
+        Strides of the convolution.
+    padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1
+        Padding value for convolution layer.
+    dilation : int or tuple/list of 2 int, default 1
+        Dilation value for convolution layer.
+    bias : bool, default False
+        Whether the layer uses a bias vector.
+    bn_eps : float, default 1e-5
+        Small float added to variance in Batch norm.
+    activation : function or str or None, default nn.ReLU(inplace=True)
+        Activation function or name of activation function.
+    """
+    return ConvBlock(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=3,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=out_channels,
+        bias=bias,
+        use_bn=True,
+        bn_eps=bn_eps,
+        activation=activation)
+
+
+def dwconv5x5_block(in_channels,
+                    out_channels,
+                    stride=1,
+                    padding=2,
+                    dilation=1,
+                    bias=False,
+                    bn_eps=1e-5,
+                    activation=nn.ReLU(inplace=True)):
+    """
+    5x5 depthwise version of the standard convolution block.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    stride : int or tuple/list of 2 int, default 1
+        Strides of the convolution.
+    padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 2
+        Padding value for convolution layer.
+    dilation : int or tuple/list of 2 int, default 1
+        Dilation value for convolution layer.
+    bias : bool, default False
+        Whether the layer uses a bias vector.
+    bn_eps : float, default 1e-5
+        Small float added to variance in Batch norm.
+    activation : function or str or None, default nn.ReLU(inplace=True)
+        Activation function or name of activation function.
+    """
+    return ConvBlock(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=5,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=out_channels,
+        bias=bias,
+        use_bn=True,
+        bn_eps=bn_eps,
+        activation=activation)
+
+
+class EffiDwsConvUnit(nn.Module):
+    """
+    EfficientNet specific depthwise separable convolution block/unit with BatchNorms and activations at each convolution
+    layers.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    stride : int or tuple/list of 2 int
+        Strides of the second convolution layer.
+    bn_eps : float
+        Small float added to variance in Batch norm.
+    activation : str
+        Name of activation function.
+    tf_mode : bool
+        Whether to use TF-like mode.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 bn_eps,
+                 activation,
+                 tf_mode):
+        super(EffiDwsConvUnit, self).__init__()
+        self.tf_mode = tf_mode
+        self.residual = (in_channels == out_channels) and (stride == 1)
+
+        self.dw_conv = dwconv3x3_block(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            padding=(0 if tf_mode else 1),
+            bn_eps=bn_eps,
+            activation=activation)
+        self.se = SEBlock(
+            channels=in_channels,
+            reduction=4,
+            mid_activation=activation)
+        self.pw_conv = conv1x1_block(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            bn_eps=bn_eps,
+            activation=None)
+
+    def forward(self, x):
+        if self.residual:
+            identity = x
+        if self.tf_mode:
+            x = F.pad(x, pad=calc_tf_padding(x, kernel_size=3))
+        x = self.dw_conv(x)
+        x = self.se(x)
+        x = self.pw_conv(x)
+        if self.residual:
+            x = x + identity
+        return x
+
+
+class EffiInvResUnit(nn.Module):
+    """
+    EfficientNet inverted residual unit.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : int or tuple/list of 2 int
+        Convolution window size.
+    stride : int or tuple/list of 2 int
+        Strides of the second convolution layer.
+    exp_factor : int
+        Factor for expansion of channels.
+    se_factor : int
+        SE reduction factor for each unit.
+    bn_eps : float
+        Small float added to variance in Batch norm.
+    activation : str
+        Name of activation function.
+    tf_mode : bool
+        Whether to use TF-like mode.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 exp_factor,
+                 se_factor,
+                 bn_eps,
+                 activation,
+                 tf_mode):
+        super(EffiInvResUnit, self).__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.tf_mode = tf_mode
+        self.residual = (in_channels == out_channels) and (stride == 1)
+        self.use_se = se_factor > 0
+        mid_channels = in_channels * exp_factor
+        dwconv_block_fn = dwconv3x3_block if kernel_size == 3 else (dwconv5x5_block if kernel_size == 5 else None)
+
+        self.conv1 = conv1x1_block(
+            in_channels=in_channels,
+            out_channels=mid_channels,
+            bn_eps=bn_eps,
+            activation=activation)
+        self.conv2 = dwconv_block_fn(
+            in_channels=mid_channels,
+            out_channels=mid_channels,
+            stride=stride,
+            padding=(0 if tf_mode else (kernel_size // 2)),
+            bn_eps=bn_eps,
+            activation=activation)
+        if self.use_se:
+            self.se = SEBlock(
+                channels=mid_channels,
+                reduction=(exp_factor * se_factor),
+                mid_activation=activation)
+        self.conv3 = conv1x1_block(
+            in_channels=mid_channels,
+            out_channels=out_channels,
+            bn_eps=bn_eps,
+            activation=None)
+
+    def forward(self, x):
+        if self.residual:
+            identity = x
+        x = self.conv1(x)
+        if self.tf_mode:
+            x = F.pad(x, pad=calc_tf_padding(x, kernel_size=self.kernel_size, stride=self.stride))
+        x = self.conv2(x)
+        if self.use_se:
+            x = self.se(x)
+        x = self.conv3(x)
+        if self.residual:
+            x = x + identity
+        return x
+
+
+class EffiInitBlock(nn.Module):
+    """
+    EfficientNet specific initial block.
+
+    Parameters:
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    bn_eps : float
+        Small float added to variance in Batch norm.
+    activation : str
+        Name of activation function.
+    tf_mode : bool
+        Whether to use TF-like mode.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 bn_eps,
+                 activation,
+                 tf_mode):
+        super(EffiInitBlock, self).__init__()
+        self.tf_mode = tf_mode
+
+        self.conv = conv3x3_block(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            stride=2,
+            padding=(0 if tf_mode else 1),
+            bn_eps=bn_eps,
+            activation=activation)
+
+    def forward(self, x):
+        if self.tf_mode:
+            x = F.pad(x, pad=calc_tf_padding(x, kernel_size=3, stride=2))
+        x = self.conv(x)
+        return x
+
+
+class EfficientNet(nn.Module):
+    """
+    EfficientNet model from 'EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,'
+    https://arxiv.org/abs/1905.11946.
+
+    Parameters:
+    ----------
+    channels : list of list of int
+        Number of output channels for each unit.
+    init_block_channels : int
+        Number of output channels for initial unit.
+    final_block_channels : int
+        Number of output channels for the final block of the feature extractor.
+    kernel_sizes : list of list of int
+        Number of kernel sizes for each unit.
+    strides_per_stage : list int
+        Stride value for the first unit of each stage.
+    expansion_factors : list of list of int
+        Number of expansion factors for each unit.
+    dropout_rate : float, default 0.2
+        Fraction of the input units to drop. Must be a number between 0 and 1.
+    tf_mode : bool, default False
+        Whether to use TF-like mode.
+    bn_eps : float, default 1e-5
+        Small float added to variance in Batch norm.
+    in_channels : int, default 3
+        Number of input channels.
+    in_size : tuple of two ints, default (224, 224)
+        Spatial size of the expected input image.
+    num_classes : int, default 1000
+        Number of classification classes.
+    """
+    def __init__(self,
+                 cfg,
+                 channels,
+                 init_block_channels,
+                 kernel_sizes,
+                 strides_per_stage,
+                 expansion_factors,
+                 tf_mode=False,
+                 bn_eps=1e-5,
+                 in_channels=3):
+        super(EfficientNet, self).__init__()
+        activation = swish()
+
+        self.out_channels = []
+        self.features = nn.Sequential()
+        self.stages = []
+        stem = EffiInitBlock(
+            in_channels=in_channels,
+            out_channels=init_block_channels,
+            bn_eps=bn_eps,
+            activation=activation,
+            tf_mode=tf_mode)
+        self.features.add_module("init_block", stem)
+        self.stages.append(stem)
+
+        in_channels = init_block_channels
+        for i, channels_per_stage in enumerate(channels):
+            kernel_sizes_per_stage = kernel_sizes[i]
+            expansion_factors_per_stage = expansion_factors[i]
+            stage = nn.Sequential()
+            for j, out_channels in enumerate(channels_per_stage):
+                kernel_size = kernel_sizes_per_stage[j]
+                expansion_factor = expansion_factors_per_stage[j]
+                stride = strides_per_stage[i] if (j == 0) else 1
+                if i == 0:
+                    stage.add_module("unit{}".format(j + 1), EffiDwsConvUnit(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        stride=stride,
+                        bn_eps=bn_eps,
+                        activation=activation,
+                        tf_mode=tf_mode))
+                else:
+                    stage.add_module("unit{}".format(j + 1), EffiInvResUnit(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        kernel_size=kernel_size,
+                        stride=stride,
+                        exp_factor=expansion_factor,
+                        se_factor=4,
+                        bn_eps=bn_eps,
+                        activation=activation,
+                        tf_mode=tf_mode))
+                in_channels = out_channels
+            if i>0:
+                self.out_channels.append(out_channels)
+            self.features.add_module("stage{}".format(i + 1), stage)
+            self.stages.append(stage)
+        # Optionally freeze (requires_grad=False) parts of the backbone
+        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)
+
+    def _freeze_backbone(self, freeze_at):
+        if freeze_at < 0:
+            return
+        for stage_index in range(freeze_at):
+            m = self.stages[stage_index]
+            for p in m.parameters():
+                p.requires_grad = False
+
+    def forward(self, x):
+        res = []
+        for i, stage in enumerate(self.stages):
+            x = stage(x)
+            if i>1:
+                res.append(x)
+        return res
+
+
+def get_efficientnet(cfg, version, tf_mode = True, bn_eps=1e-5, **kwargs):
+    if version == "b0":
+        depth_factor = 1.0
+        width_factor = 1.0
+    elif version == "b1":
+        depth_factor = 1.1
+        width_factor = 1.0
+    elif version == "b2":
+        depth_factor = 1.2
+        width_factor = 1.1
+    elif version == "b3":
+        depth_factor = 1.4
+        width_factor = 1.2
+    elif version == "b4":
+        depth_factor = 1.8
+        width_factor = 1.4
+    elif version == "b5":
+        depth_factor = 2.2
+        width_factor = 1.6
+    elif version == "b6":
+        depth_factor = 2.6
+        width_factor = 1.8
+    elif version == "b7":
+        depth_factor = 3.1
+        width_factor = 2.0
+    elif version == "b8":
+        depth_factor = 3.6
+        width_factor = 2.2
+    else:
+        raise ValueError("Unsupported EfficientNet version {}".format(version))
+
+    init_block_channels = 32
+    layers = [1, 2, 2, 3, 3, 4, 1]
+    downsample = [1, 1, 1, 1, 0, 1, 0]
+    channels_per_layers = [16, 24, 40, 80, 112, 192, 320]
+    expansion_factors_per_layers = [1, 6, 6, 6, 6, 6, 6]
+    kernel_sizes_per_layers = [3, 3, 5, 3, 5, 5, 3]
+    strides_per_stage = [1, 2, 2, 2, 1, 2, 1]
+
+    layers = [int(math.ceil(li * depth_factor)) for li in layers]
+    channels_per_layers = [round_channels(ci * width_factor) for ci in channels_per_layers]
+
+    from functools import reduce
+    channels = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]],
+                      zip(channels_per_layers, layers, downsample), [])
+    kernel_sizes = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]],
+                          zip(kernel_sizes_per_layers, layers, downsample), [])
+    expansion_factors = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]],
+                               zip(expansion_factors_per_layers, layers, downsample), [])
+    strides_per_stage = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]],
+                               zip(strides_per_stage, layers, downsample), [])
+    strides_per_stage = [si[0] for si in strides_per_stage]
+
+    init_block_channels = round_channels(init_block_channels * width_factor)
+
+    net = EfficientNet(
+        cfg,
+        channels=channels,
+        init_block_channels=init_block_channels,
+        kernel_sizes=kernel_sizes,
+        strides_per_stage=strides_per_stage,
+        expansion_factors=expansion_factors,
+        tf_mode=tf_mode,
+        bn_eps=bn_eps,
+        **kwargs)
+
+    return net
diff --git a/maskrcnn_benchmark/modeling/backbone/fbnet.py b/maskrcnn_benchmark/modeling/backbone/fbnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cc2823f3bd2f06cc86b3e1bb597fb20f219817d
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/fbnet.py
@@ -0,0 +1,536 @@
+"""
+FBNet model builder
+"""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import copy
+import logging
+import math
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from torch.nn import BatchNorm2d, SyncBatchNorm
+from maskrcnn_benchmark.layers import Conv2d, interpolate
+from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d
+from maskrcnn_benchmark.layers.misc import _NewEmptyTensorOp
+
+
+logger = logging.getLogger(__name__)
+
+
+def _py2_round(x):
+    return math.floor(x + 0.5) if x >= 0.0 else math.ceil(x - 0.5)
+
+
+def _get_divisible_by(num, divisible_by, min_val):
+    ret = int(num)
+    if divisible_by > 0 and num % divisible_by != 0:
+        ret = int((_py2_round(num / divisible_by) or min_val) * divisible_by)
+    return ret
+
+
+class Identity(nn.Module):
+    def __init__(self, C_in, C_out, stride):
+        super(Identity, self).__init__()
+        self.conv = (
+            ConvBNRelu(
+                C_in,
+                C_out,
+                kernel=1,
+                stride=stride,
+                pad=0,
+                no_bias=1,
+                use_relu="relu",
+                bn_type="bn",
+            )
+            if C_in != C_out or stride != 1
+            else None
+        )
+
+    def forward(self, x):
+        if self.conv:
+            out = self.conv(x)
+        else:
+            out = x
+        return out
+
+
+class CascadeConv3x3(nn.Sequential):
+    def __init__(self, C_in, C_out, stride):
+        assert stride in [1, 2]
+        ops = [
+            Conv2d(C_in, C_in, 3, stride, 1, bias=False),
+            BatchNorm2d(C_in),
+            nn.ReLU(inplace=True),
+            Conv2d(C_in, C_out, 3, 1, 1, bias=False),
+            BatchNorm2d(C_out),
+        ]
+        super(CascadeConv3x3, self).__init__(*ops)
+        self.res_connect = (stride == 1) and (C_in == C_out)
+
+    def forward(self, x):
+        y = super(CascadeConv3x3, self).forward(x)
+        if self.res_connect:
+            y += x
+        return y
+
+
+class Shift(nn.Module):
+    def __init__(self, C, kernel_size, stride, padding):
+        super(Shift, self).__init__()
+        self.C = C
+        kernel = torch.zeros((C, 1, kernel_size, kernel_size), dtype=torch.float32)
+        ch_idx = 0
+
+        assert stride in [1, 2]
+        self.stride = stride
+        self.padding = padding
+        self.kernel_size = kernel_size
+        self.dilation = 1
+
+        hks = kernel_size // 2
+        ksq = kernel_size ** 2
+
+        for i in range(kernel_size):
+            for j in range(kernel_size):
+                if i == hks and j == hks:
+                    num_ch = C // ksq + C % ksq
+                else:
+                    num_ch = C // ksq
+                kernel[ch_idx : ch_idx + num_ch, 0, i, j] = 1
+                ch_idx += num_ch
+
+        self.register_parameter("bias", None)
+        self.kernel = nn.Parameter(kernel, requires_grad=False)
+
+    def forward(self, x):
+        if x.numel() > 0:
+            return nn.functional.conv2d(
+                x,
+                self.kernel,
+                self.bias,
+                (self.stride, self.stride),
+                (self.padding, self.padding),
+                self.dilation,
+                self.C,  # groups
+            )
+
+        output_shape = [
+            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
+            for i, p, di, k, d in zip(
+                x.shape[-2:],
+                (self.padding, self.dilation),
+                (self.dilation, self.dilation),
+                (self.kernel_size, self.kernel_size),
+                (self.stride, self.stride),
+            )
+        ]
+        output_shape = [x.shape[0], self.C] + output_shape
+        return _NewEmptyTensorOp.apply(x, output_shape)
+
+
+class ShiftBlock5x5(nn.Sequential):
+    def __init__(self, C_in, C_out, expansion, stride):
+        assert stride in [1, 2]
+        self.res_connect = (stride == 1) and (C_in == C_out)
+
+        C_mid = _get_divisible_by(C_in * expansion, 8, 8)
+
+        ops = [
+            # pw
+            Conv2d(C_in, C_mid, 1, 1, 0, bias=False),
+            BatchNorm2d(C_mid),
+            nn.ReLU(inplace=True),
+            # shift
+            Shift(C_mid, 5, stride, 2),
+            # pw-linear
+            Conv2d(C_mid, C_out, 1, 1, 0, bias=False),
+            BatchNorm2d(C_out),
+        ]
+        super(ShiftBlock5x5, self).__init__(*ops)
+
+    def forward(self, x):
+        y = super(ShiftBlock5x5, self).forward(x)
+        if self.res_connect:
+            y += x
+        return y
+
+
+class ChannelShuffle(nn.Module):
+    def __init__(self, groups):
+        super(ChannelShuffle, self).__init__()
+        self.groups = groups
+
+    def forward(self, x):
+        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
+        N, C, H, W = x.size()
+        g = self.groups
+        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
+            g, C
+        )
+        return (
+            x.view(N, g, int(C / g), H, W)
+            .permute(0, 2, 1, 3, 4)
+            .contiguous()
+            .view(N, C, H, W)
+        )
+
+
+class ConvBNRelu(nn.Sequential):
+    def __init__(
+        self,
+        input_depth,
+        output_depth,
+        kernel,
+        stride,
+        pad,
+        no_bias,
+        use_relu,
+        bn_type,
+        group=1,
+        *args,
+        **kwargs
+    ):
+        super(ConvBNRelu, self).__init__()
+
+        assert use_relu in ["relu", None]
+        if isinstance(bn_type, (list, tuple)):
+            assert len(bn_type) == 2
+            assert bn_type[0] == "gn"
+            gn_group = bn_type[1]
+            bn_type = bn_type[0]
+        assert bn_type in ["bn", "nsbn", "sbn", "af", "gn", None]
+        assert stride in [1, 2, 4]
+
+        op = Conv2d(
+            input_depth,
+            output_depth,
+            kernel_size=kernel,
+            stride=stride,
+            padding=pad,
+            bias=not no_bias,
+            groups=group,
+            *args,
+            **kwargs
+        )
+        nn.init.kaiming_normal_(op.weight, mode="fan_out", nonlinearity="relu")
+        if op.bias is not None:
+            nn.init.constant_(op.bias, 0.0)
+        self.add_module("conv", op)
+
+        if bn_type == "bn":
+            bn_op = BatchNorm2d(output_depth)
+        elif bn_type == "sbn":
+            bn_op = SyncBatchNorm(output_depth)
+        elif bn_type == "nsbn":
+            bn_op = NaiveSyncBatchNorm2d(output_depth)
+        elif bn_type == "gn":
+            bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=output_depth)
+        elif bn_type == "af":
+            bn_op = FrozenBatchNorm2d(output_depth)
+        if bn_type is not None:
+            self.add_module("bn", bn_op)
+
+        if use_relu == "relu":
+            self.add_module("relu", nn.ReLU(inplace=True))
+
+
+class SEModule(nn.Module):
+    reduction = 4
+
+    def __init__(self, C):
+        super(SEModule, self).__init__()
+        mid = max(C // self.reduction, 8)
+        conv1 = Conv2d(C, mid, 1, 1, 0)
+        conv2 = Conv2d(mid, C, 1, 1, 0)
+
+        self.op = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1), conv1, nn.ReLU(inplace=True), conv2, nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        return x * self.op(x)
+
+
+class Upsample(nn.Module):
+    def __init__(self, scale_factor, mode, align_corners=None):
+        super(Upsample, self).__init__()
+        self.scale = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return interpolate(
+            x, scale_factor=self.scale, mode=self.mode,
+            align_corners=self.align_corners
+        )
+
+
+def _get_upsample_op(stride):
+    assert (
+        stride in [1, 2, 4]
+        or stride in [-1, -2, -4]
+        or (isinstance(stride, tuple) and all(x in [-1, -2, -4] for x in stride))
+    )
+
+    scales = stride
+    ret = None
+    if isinstance(stride, tuple) or stride < 0:
+        scales = [-x for x in stride] if isinstance(stride, tuple) else -stride
+        stride = 1
+        ret = Upsample(scale_factor=scales, mode="nearest", align_corners=None)
+
+    return ret, stride
+
+
+class IRFBlock(nn.Module):
+    def __init__(
+        self,
+        input_depth,
+        output_depth,
+        expansion,
+        stride,
+        bn_type="bn",
+        kernel=3,
+        width_divisor=1,
+        shuffle_type=None,
+        pw_group=1,
+        se=False,
+        cdw=False,
+        dw_skip_bn=False,
+        dw_skip_relu=False,
+    ):
+        super(IRFBlock, self).__init__()
+
+        assert kernel in [1, 3, 5, 7], kernel
+
+        self.use_res_connect = stride == 1 and input_depth == output_depth
+        self.output_depth = output_depth
+
+        mid_depth = int(input_depth * expansion)
+        mid_depth = _get_divisible_by(mid_depth, width_divisor, width_divisor)
+
+        # pw
+        self.pw = ConvBNRelu(
+            input_depth,
+            mid_depth,
+            kernel=1,
+            stride=1,
+            pad=0,
+            no_bias=1,
+            use_relu="relu",
+            bn_type=bn_type,
+            group=pw_group,
+        )
+
+        # negative stride to do upsampling
+        self.upscale, stride = _get_upsample_op(stride)
+
+        # dw
+        if kernel == 1:
+            self.dw = nn.Sequential()
+        elif cdw:
+            dw1 = ConvBNRelu(
+                mid_depth,
+                mid_depth,
+                kernel=kernel,
+                stride=stride,
+                pad=(kernel // 2),
+                group=mid_depth,
+                no_bias=1,
+                use_relu="relu",
+                bn_type=bn_type,
+            )
+            dw2 = ConvBNRelu(
+                mid_depth,
+                mid_depth,
+                kernel=kernel,
+                stride=1,
+                pad=(kernel // 2),
+                group=mid_depth,
+                no_bias=1,
+                use_relu="relu" if not dw_skip_relu else None,
+                bn_type=bn_type if not dw_skip_bn else None,
+            )
+            self.dw = nn.Sequential(OrderedDict([("dw1", dw1), ("dw2", dw2)]))
+        else:
+            self.dw = ConvBNRelu(
+                mid_depth,
+                mid_depth,
+                kernel=kernel,
+                stride=stride,
+                pad=(kernel // 2),
+                group=mid_depth,
+                no_bias=1,
+                use_relu="relu" if not dw_skip_relu else None,
+                bn_type=bn_type if not dw_skip_bn else None,
+            )
+
+        # pw-linear
+        self.pwl = ConvBNRelu(
+            mid_depth,
+            output_depth,
+            kernel=1,
+            stride=1,
+            pad=0,
+            no_bias=1,
+            use_relu=None,
+            bn_type=bn_type,
+            group=pw_group,
+        )
+
+        self.shuffle_type = shuffle_type
+        if shuffle_type is not None:
+            self.shuffle = ChannelShuffle(pw_group)
+
+        self.se4 = SEModule(output_depth) if se else nn.Sequential()
+
+        self.output_depth = output_depth
+
+    def forward(self, x):
+        y = self.pw(x)
+        if self.shuffle_type == "mid":
+            y = self.shuffle(y)
+        if self.upscale is not None:
+            y = self.upscale(y)
+        y = self.dw(y)
+        y = self.pwl(y)
+        if self.use_res_connect:
+            y += x
+        y = self.se4(y)
+        return y
+
+
+
+skip = lambda C_in, C_out, stride, **kwargs: Identity(
+    C_in, C_out, stride
+)
+basic_block = lambda C_in, C_out, stride, **kwargs: CascadeConv3x3(
+    C_in, C_out, stride
+)
+# layer search 2
+ir_k3_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=3, **kwargs
+)
+ir_k3_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 3, stride, kernel=3, **kwargs
+)
+ir_k3_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 6, stride, kernel=3, **kwargs
+)
+ir_k3_s4 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, **kwargs
+)
+ir_k5_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=5, **kwargs
+)
+ir_k5_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 3, stride, kernel=5, **kwargs
+)
+ir_k5_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 6, stride, kernel=5, **kwargs
+)
+ir_k5_s4 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, **kwargs
+)
+# layer search se
+ir_k3_e1_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=3, se=True, **kwargs
+)
+ir_k3_e3_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 3, stride, kernel=3, se=True, **kwargs
+)
+ir_k3_e6_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 6, stride, kernel=3, se=True, **kwargs
+)
+ir_k3_s4_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in,
+    C_out,
+    4,
+    stride,
+    kernel=3,
+    shuffle_type=mid,
+    pw_group=4,
+    se=True,
+    **kwargs
+)
+ir_k5_e1_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=5, se=True, **kwargs
+)
+ir_k5_e3_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 3, stride, kernel=5, se=True, **kwargs
+)
+ir_k5_e6_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 6, stride, kernel=5, se=True, **kwargs
+)
+ir_k5_s4_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in,
+    C_out,
+    4,
+    stride,
+    kernel=5,
+    shuffle_type="mid",
+    pw_group=4,
+    se=True,
+    **kwargs
+)
+# layer search 3 (in addition to layer search 2)
+ir_k3_s2 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, **kwargs
+)
+ir_k5_s2 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, **kwargs
+)
+ir_k3_s2_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in,
+    C_out,
+    1,
+    stride,
+    kernel=3,
+    shuffle_type="mid",
+    pw_group=2,
+    se=True,
+    **kwargs
+)
+ir_k5_s2_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in,
+    C_out,
+    1,
+    stride,
+    kernel=5,
+    shuffle_type="mid",
+    pw_group=2,
+    se=True,
+    **kwargs
+)
+# layer search 4 (in addition to layer search 3)
+ir_k33_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=3, cdw=True, **kwargs
+)
+ir_k33_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 3, stride, kernel=3, cdw=True, **kwargs
+)
+ir_k33_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 6, stride, kernel=3, cdw=True, **kwargs
+)
+# layer search 5 (in addition to layer search 4)
+ir_k7_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=7, **kwargs
+)
+ir_k7_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 3, stride, kernel=7, **kwargs
+)
+ir_k7_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 6, stride, kernel=7, **kwargs
+)
+ir_k7_sep_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 1, stride, kernel=7, cdw=True, **kwargs
+)
+ir_k7_sep_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 3, stride, kernel=7, cdw=True, **kwargs
+)
+ir_k7_sep_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
+    C_in, C_out, 6, stride, kernel=7, cdw=True, **kwargs
+)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/fpn.py b/maskrcnn_benchmark/modeling/backbone/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..90bd853325190618d82addd46ac0d08f44742aa7
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/fpn.py
@@ -0,0 +1,167 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class FPN(nn.Module):
+    """
+    Module that adds FPN on top of a list of feature maps.
+    The feature maps are currently supposed to be in increasing depth
+    order, and must be consecutive
+    """
+
+    def __init__(
+        self, in_channels_list, out_channels, conv_block, top_blocks=None, drop_block=None, use_spp=False, use_pan=False,
+            return_swint_feature_before_fusion=False
+    ):
+        """
+        Arguments:
+            in_channels_list (list[int]): number of channels for each feature map that
+                will be fed
+            out_channels (int): number of channels of the FPN representation
+            top_blocks (nn.Module or None): if provided, an extra operation will
+                be performed on the output of the last (smallest resolution)
+                FPN output, and the result will extend the result list
+        """
+        super(FPN, self).__init__()
+        self.inner_blocks = []
+        self.layer_blocks = []
+        self.pan_blocks = [] if use_pan else None
+        self.spp_block = SPPLayer() if use_spp else None
+        self.return_swint_feature_before_fusion = return_swint_feature_before_fusion
+        for idx, in_channels in enumerate(in_channels_list, 1):
+            inner_block = "fpn_inner{}".format(idx)
+            layer_block = "fpn_layer{}".format(idx)
+
+            if in_channels == 0:
+                continue
+            if idx==len(in_channels_list) and use_spp:
+                in_channels = in_channels*4
+            inner_block_module = conv_block(in_channels, out_channels, 1)
+            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
+            self.add_module(inner_block, inner_block_module)
+            self.add_module(layer_block, layer_block_module)
+            self.inner_blocks.append(inner_block)
+            self.layer_blocks.append(layer_block)
+
+            if use_pan:
+                pan_in_block = "pan_in_layer{}".format(idx)
+                pan_in_block_module = conv_block(out_channels, out_channels, 3, 2)
+                self.add_module(pan_in_block, pan_in_block_module)
+                pan_out_block = "pan_out_layer{}".format(idx)
+                pan_out_block_module = conv_block(out_channels, out_channels, 3, 1)
+                self.add_module(pan_out_block, pan_out_block_module)
+                self.pan_blocks.append([pan_in_block, pan_out_block])
+
+        self.top_blocks = top_blocks
+        self.drop_block = drop_block
+
+    def forward(self, x):
+        """
+        Arguments:
+            x (list[Tensor]): feature maps for each feature level.
+        Returns:
+            results (tuple[Tensor]): feature maps after FPN layers.
+                They are ordered from highest resolution first.
+        """
+        if type(x) is tuple:
+            # for the case of VL backbone
+            x, x_text = x[0], x[1]
+        # print([v.shape for v in x])
+        swint_feature_c4 = None
+        if self.return_swint_feature_before_fusion:
+            # TODO: here we only return last single scale feature map before the backbone fusion, should be more flexible
+            swint_feature_c4 = x[-2]
+
+        if self.spp_block:
+            last_inner = getattr(self, self.inner_blocks[-1])(self.spp_block(x[-1]))
+        else:
+            last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
+        results = []
+        results.append(getattr(self, self.layer_blocks[-1])(last_inner))
+        for feature, inner_block, layer_block in zip(
+            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
+        ):
+            if not inner_block:
+                continue
+            inner_lateral = getattr(self, inner_block)(feature)
+
+            if inner_lateral.shape[-2:] != last_inner.shape[-2:]:
+                # TODO: could also give size instead of
+                inner_top_down = F.interpolate(last_inner, size=inner_lateral.shape[-2:], mode="nearest")
+            else:
+                inner_top_down = last_inner
+
+            # TODO use size instead of scale to make it robust to different sizes
+            # inner_top_down = F.upsample(last_inner, size=inner_lateral.shape[-2:],
+            # mode='bilinear', align_corners=False)
+            last_inner = inner_lateral + inner_top_down
+            if self.drop_block and self.training:
+                results.insert(0, getattr(self, layer_block)(self.drop_block(last_inner)))
+            else:
+                results.insert(0, getattr(self, layer_block)(last_inner))
+
+        if self.pan_blocks:
+            pan_results = []
+            last_outer = results[0]
+            pan_results.append(last_outer)
+            for outer_top_down, pan_block in zip(results[1:], self.pan_blocks):
+
+                if self.drop_block and self.training:
+                    pan_lateral = getattr(self, pan_block[0])(self.drop_block(last_outer))
+                else:
+                    pan_lateral = getattr(self, pan_block[0])(last_outer)
+
+                last_outer = getattr(self, pan_block[1])(pan_lateral + outer_top_down)
+                pan_results.append(last_outer)
+            results = pan_results
+
+        if isinstance(self.top_blocks, LastLevelP6P7):
+            last_results = self.top_blocks(x[-1], results[-1])
+            results.extend(last_results)
+        elif isinstance(self.top_blocks, LastLevelMaxPool):
+            last_results = self.top_blocks(results[-1])
+            results.extend(last_results)
+
+        try:
+            return tuple(results), x_text, swint_feature_c4
+        except NameError as e:
+            return tuple(results)
+
+
+class LastLevelMaxPool(nn.Module):
+    def forward(self, x):
+        return [F.max_pool2d(x, 1, 2, 0)]
+
+
+class LastLevelP6P7(nn.Module):
+    """
+    This module is used in RetinaNet to generate extra layers, P6 and P7.
+    """
+    def __init__(self, in_channels, out_channels):
+        super(LastLevelP6P7, self).__init__()
+        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
+        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
+        for module in [self.p6, self.p7]:
+            nn.init.kaiming_uniform_(module.weight, a=1)
+            nn.init.constant_(module.bias, 0)
+        self.use_P5 = in_channels == out_channels
+
+    def forward(self, c5, p5):
+        x = p5 if self.use_P5 else c5
+        p6 = self.p6(x)
+        p7 = self.p7(F.relu(p6))
+        return [p6, p7]
+
+
+class SPPLayer(nn.Module):
+    def __init__(self):
+        super(SPPLayer, self).__init__()
+
+    def forward(self, x):
+        x_1 = x
+        x_2 = F.max_pool2d(x, 5, stride=1, padding=2)
+        x_3 = F.max_pool2d(x, 9, stride=1, padding=4)
+        x_4 = F.max_pool2d(x, 13, stride=1, padding=6)
+        out = torch.cat((x_1, x_2, x_3, x_4),dim=1)
+        return out
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/mixer.py b/maskrcnn_benchmark/modeling/backbone/mixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4782d50863a4da9070285a9cd3093db4fbcf6f8
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/mixer.py
@@ -0,0 +1,23 @@
+import torch
+from torch import nn
+
+class MixedOperationRandom(nn.Module):
+    def __init__(self, search_ops):
+        super(MixedOperationRandom, self).__init__()
+        self.ops = nn.ModuleList(search_ops)
+        self.num_ops = len(search_ops)
+
+    def forward(self, x, x_path=None):
+        if x_path is None:
+            output = sum(op(x) for op in self.ops) / self.num_ops
+        else:
+            assert isinstance(x_path, (int, float)) and 0 <= x_path < self.num_ops or isinstance(x_path, torch.Tensor)
+            if isinstance(x_path, (int, float)):
+                x_path = int(x_path)
+                assert 0 <= x_path < self.num_ops
+                output = self.ops[x_path](x)
+            elif isinstance(x_path, torch.Tensor):
+                assert x_path.size(0) == x.size(0), 'batch_size should match length of y_idx'
+                output = torch.cat([self.ops[int(x_path[i].item())](x.narrow(0, i, 1))
+                                    for i in range(x.size(0))], dim=0)
+        return output
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/ops.py b/maskrcnn_benchmark/modeling/backbone/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c36ccebb57d3207d97e32babd84e21b65a2ec2
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/ops.py
@@ -0,0 +1,71 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv7x7(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """7x7 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride,
+                     padding=3*dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """5x5 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
+                     padding=2*dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+def maxpool(**kwargs):
+    return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+
+def avgpool(**kwargs):
+    return nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+
+def dropout(prob):
+    return nn.Dropout(prob)
+
+
+conv3x3sep = lambda i, o, s=1: conv3x3(i, o, s, groups=i)
+conv3x3g2 = lambda i, o, s=1: conv3x3(i, o, s, groups=2)
+conv3x3g4 = lambda i, o, s=1: conv3x3(i, o, s, groups=4)
+conv3x3g8 = lambda i, o, s=1: conv3x3(i, o, s, groups=8)
+conv3x3dw = lambda i, o, s=1: conv3x3(i, o, s, groups=i)
+
+conv3x3d2 = lambda i, o, s=1: conv3x3(i, o, s, dilation=2)
+conv3x3d3 = lambda i, o, s=1: conv3x3(i, o, s, dilation=3)
+conv3x3d4 = lambda i, o, s=1: conv3x3(i, o, s, dilation=4)
+
+
+conv5x5sep = lambda i, o, s=1: conv5x5(i, o, s, groups=i)
+conv5x5g2 = lambda i, o, s=1: conv5x5(i, o, s, groups=2)
+conv5x5g4 = lambda i, o, s=1: conv5x5(i, o, s, groups=4)
+conv5x5g8 = lambda i, o, s=1: conv5x5(i, o, s, groups=8)
+conv5x5dw = lambda i, o, s=1: conv5x5(i, o, s, groups=i)
+
+
+conv5x5d2 = lambda i, o, s=1: conv5x5(i, o, s, dilation=2)
+conv5x5d3 = lambda i, o, s=1: conv5x5(i, o, s, dilation=3)
+conv5x5d4 = lambda i, o, s=1: conv5x5(i, o, s, dilation=4)
+
+conv7x7sep = lambda i, o, s=1: conv7x7(i, o, s, groups=i)
+conv7x7g2 = lambda i, o, s=1: conv7x7(i, o, s, groups=2)
+conv7x7g4 = lambda i, o, s=1: conv7x7(i, o, s, groups=4)
+conv7x7g8 = lambda i, o, s=1: conv7x7(i, o, s, groups=8)
+conv7x7dw = lambda i, o, s=1: conv7x7(i, o, s, groups=i)
+
+conv7x7d2 = lambda i, o, s=1: conv7x7(i, o, s, dilation=2)
+conv7x7d3 = lambda i, o, s=1: conv7x7(i, o, s, dilation=3)
+conv7x7d4 = lambda i, o, s=1: conv7x7(i, o, s, dilation=4)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/resnet.py b/maskrcnn_benchmark/modeling/backbone/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f27a1edc470a0c49778369abceeefbeecc792f85
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/resnet.py
@@ -0,0 +1,643 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+Variant of the resnet module that takes cfg as an argument.
+Example usage. Strings may be specified in the config file.
+    model = ResNet(
+        "StemWithFixedBatchNorm",
+        "BottleneckWithFixedBatchNorm",
+        "ResNet50StagesTo4",
+    )
+OR:
+    model = ResNet(
+        "StemWithGN",
+        "BottleneckWithGN",
+        "ResNet50StagesTo4",
+    )
+Custom implementations may be written in user code and hooked in via the
+`register_*` functions.
+"""
+from collections import namedtuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import BatchNorm2d, SyncBatchNorm
+
+from maskrcnn_benchmark.layers import FrozenBatchNorm2d, NaiveSyncBatchNorm2d
+from maskrcnn_benchmark.layers import Conv2d, DFConv2d, SELayer
+from maskrcnn_benchmark.modeling.make_layers import group_norm
+from maskrcnn_benchmark.utils.registry import Registry
+
+
+# ResNet stage specification
+StageSpec = namedtuple(
+    "StageSpec",
+    [
+        "index",  # Index of the stage, eg 1, 2, ..,. 5
+        "block_count",  # Number of residual blocks in the stage
+        "return_features",  # True => return the last feature map from this stage
+    ],
+)
+
+# -----------------------------------------------------------------------------
+# Standard ResNet models
+# -----------------------------------------------------------------------------
+# ResNet-50 (including all stages)
+ResNet50StagesTo5 = tuple(
+    StageSpec(index=i, block_count=c, return_features=r)
+    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True))
+)
+# ResNet-50 up to stage 4 (excludes stage 5)
+ResNet50StagesTo4 = tuple(
+    StageSpec(index=i, block_count=c, return_features=r)
+    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True))
+)
+# ResNet-101 (including all stages)
+ResNet101StagesTo5 = tuple(
+    StageSpec(index=i, block_count=c, return_features=r)
+    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, False), (4, 3, True))
+)
+# ResNet-101 up to stage 4 (excludes stage 5)
+ResNet101StagesTo4 = tuple(
+    StageSpec(index=i, block_count=c, return_features=r)
+    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, True))
+)
+# ResNet-50-FPN (including all stages)
+ResNet50FPNStagesTo5 = tuple(
+    StageSpec(index=i, block_count=c, return_features=r)
+    for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True))
+)
+# ResNet-101-FPN (including all stages)
+ResNet101FPNStagesTo5 = tuple(
+    StageSpec(index=i, block_count=c, return_features=r)
+    for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True))
+)
+# ResNet-152-FPN (including all stages)
+ResNet152FPNStagesTo5 = tuple(
+    StageSpec(index=i, block_count=c, return_features=r)
+    for (i, c, r) in ((1, 3, True), (2, 8, True), (3, 36, True), (4, 3, True))
+)
+
+class ResNet(nn.Module):
+    def __init__(self, cfg):
+        super(ResNet, self).__init__()
+
+        # If we want to use the cfg in forward(), then we should make a copy
+        # of it and store it for later use:
+        # self.cfg = cfg.clone()
+
+        # Translate string names to implementations
+        norm_level = None
+        stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC]
+        stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY]
+        transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC]
+
+        if cfg.MODEL.BACKBONE.USE_BN:
+            stem_module = StemWithBatchNorm
+            transformation_module = BottleneckWithBatchNorm
+            norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL
+        elif cfg.MODEL.BACKBONE.USE_NSYNCBN:
+            stem_module = StemWithNaiveSyncBatchNorm
+            transformation_module = BottleneckWithNaiveSyncBatchNorm
+            norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL
+        elif cfg.MODEL.BACKBONE.USE_SYNCBN:
+            stem_module = StemWithSyncBatchNorm
+            transformation_module = BottleneckWithSyncBatchNorm
+            norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL
+
+        # Construct the stem module
+        self.stem = stem_module(cfg)
+
+        # Constuct the specified ResNet stages
+        num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
+        width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
+        in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
+        stage2_bottleneck_channels = num_groups * width_per_group
+        stage2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
+        with_se = cfg.MODEL.RESNETS.WITH_SE
+
+        self.stages = []
+        self.out_channels = []
+        self.return_features = {}
+        for stage_spec in stage_specs:
+            name = "layer" + str(stage_spec.index)
+            stage2_relative_factor = 2 ** (stage_spec.index - 1)
+            bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor
+            out_channels = stage2_out_channels * stage2_relative_factor
+            stage_with_dcn = cfg.MODEL.RESNETS.STAGE_WITH_DCN[stage_spec.index - 1]
+            if cfg.MODEL.RESNETS.USE_AVG_DOWN:
+                avg_down_stride = 1 if stage_spec.index==1 else 2
+            else:
+                avg_down_stride = 0
+            module = _make_stage(
+                transformation_module,
+                in_channels,
+                bottleneck_channels,
+                out_channels,
+                stage_spec.block_count,
+                num_groups,
+                cfg.MODEL.RESNETS.STRIDE_IN_1X1,
+                first_stride=int(stage_spec.index > 1) + 1,
+                dcn_config={
+                    "stage_with_dcn": stage_with_dcn,
+                    "with_modulated_dcn": cfg.MODEL.RESNETS.WITH_MODULATED_DCN,
+                    "deformable_groups": cfg.MODEL.RESNETS.DEFORMABLE_GROUPS,
+                },
+                norm_level=norm_level,
+                with_se=with_se,
+                avg_down_stride=avg_down_stride
+            )
+            in_channels = out_channels
+            self.add_module(name, module)
+            self.stages.append(name)
+            self.out_channels.append(out_channels)
+            self.return_features[name] = stage_spec.return_features
+
+        # Optionally freeze (requires_grad=False) parts of the backbone
+        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)
+
+    def _freeze_backbone(self, freeze_at):
+        if freeze_at < 0:
+            return
+        for stage_index in range(freeze_at):
+            if stage_index == 0:
+                m = self.stem  # stage 0 is the stem
+            else:
+                m = getattr(self, "layer" + str(stage_index))
+            for p in m.parameters():
+                p.requires_grad = False
+
+    def forward(self, x):
+        outputs = []
+        x = self.stem(x)
+        for stage_name in self.stages:
+            x = getattr(self, stage_name)(x)
+            if self.return_features[stage_name]:
+                outputs.append(x)
+        return outputs
+
+
+class ResNetHead(nn.Module):
+    def __init__(
+        self,
+        block_module,
+        stages,
+        num_groups=1,
+        width_per_group=64,
+        stride_in_1x1=True,
+        stride_init=None,
+        res2_out_channels=256,
+        dilation=1,
+        dcn_config=None
+    ):
+        super(ResNetHead, self).__init__()
+
+        stage2_relative_factor = 2 ** (stages[0].index - 1)
+        stage2_bottleneck_channels = num_groups * width_per_group
+        out_channels = res2_out_channels * stage2_relative_factor
+        in_channels = out_channels // 2
+        bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor
+
+        block_module = _TRANSFORMATION_MODULES[block_module]
+
+        self.stages = []
+        stride = stride_init
+        for stage in stages:
+            name = "layer" + str(stage.index)
+            if not stride:
+                stride = int(stage.index > 1) + 1
+            module = _make_stage(
+                block_module,
+                in_channels,
+                bottleneck_channels,
+                out_channels,
+                stage.block_count,
+                num_groups,
+                stride_in_1x1,
+                first_stride=stride,
+                dilation=dilation,
+                dcn_config=dcn_config
+            )
+            stride = None
+            self.add_module(name, module)
+            self.stages.append(name)
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        for stage in self.stages:
+            x = getattr(self, stage)(x)
+        return x
+
+
+def _make_stage(
+    transformation_module,
+    in_channels,
+    bottleneck_channels,
+    out_channels,
+    block_count,
+    num_groups,
+    stride_in_1x1,
+    first_stride,
+    dilation=1,
+    dcn_config=None,
+    norm_level=None,
+    **kwargs
+):
+    blocks = []
+    stride = first_stride
+    for li in range(block_count):
+        if norm_level is not None:
+            layer_module = BottleneckWithFixedBatchNorm
+            if norm_level >= 1 and li == 0:
+                layer_module = transformation_module
+            if norm_level >= 2 and li == block_count - 1:
+                layer_module = transformation_module
+            if norm_level >= 3:
+                layer_module = transformation_module
+        else:
+            layer_module = transformation_module
+
+        blocks.append(
+            layer_module(
+                in_channels,
+                bottleneck_channels,
+                out_channels,
+                num_groups,
+                stride_in_1x1,
+                stride,
+                dilation=dilation,
+                dcn_config=dcn_config,
+                **kwargs
+            )
+        )
+        stride = 1
+        in_channels = out_channels
+    return nn.Sequential(*blocks)
+
+
+class Bottleneck(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        bottleneck_channels,
+        out_channels,
+        num_groups,
+        stride_in_1x1,
+        stride,
+        dilation,
+        norm_func,
+        dcn_config,
+        with_se=False,
+        avg_down_stride=0,
+    ):
+        super(Bottleneck, self).__init__()
+
+        self.downsample = None
+        if in_channels != out_channels:
+            down_stride = stride if dilation == 1 else 1
+            if avg_down_stride>0:
+                self.downsample = nn.Sequential(
+                    nn.AvgPool2d(
+                        kernel_size=avg_down_stride,
+                        stride=avg_down_stride,
+                        ceil_mode=True,
+                        count_include_pad=False
+                    ),
+                    nn.Conv2d(
+                        in_channels, out_channels,
+                        kernel_size=1, stride=1, bias=False
+                    ),
+                    norm_func(out_channels),
+                )
+            else:
+                self.downsample = nn.Sequential(
+                    Conv2d(
+                        in_channels, out_channels,
+                        kernel_size=1, stride=down_stride, bias=False
+                    ),
+                    norm_func(out_channels),
+                )
+            for modules in [self.downsample,]:
+                for l in modules.modules():
+                    if isinstance(l, Conv2d):
+                        nn.init.kaiming_uniform_(l.weight, a=1)
+
+        if dilation > 1:
+            stride = 1 # reset to be 1
+
+        # The original MSRA ResNet models have stride in the first 1x1 conv
+        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
+        # stride in the 3x3 conv
+        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
+
+        self.conv1 = Conv2d(
+            in_channels,
+            bottleneck_channels,
+            kernel_size=1,
+            stride=stride_1x1,
+            bias=False,
+        )
+        self.bn1 = norm_func(bottleneck_channels)
+        # TODO: specify init for the above
+        with_dcn = dcn_config.get("stage_with_dcn", False)
+        if with_dcn:
+            deformable_groups = dcn_config.get("deformable_groups", 1)
+            with_modulated_dcn = dcn_config.get("with_modulated_dcn", False)
+            self.conv2 = DFConv2d(
+                bottleneck_channels,
+                bottleneck_channels,
+                with_modulated_dcn=with_modulated_dcn,
+                kernel_size=3,
+                stride=stride_3x3,
+                groups=num_groups,
+                dilation=dilation,
+                deformable_groups=deformable_groups,
+                bias=False
+            )
+        else:
+            self.conv2 = Conv2d(
+                bottleneck_channels,
+                bottleneck_channels,
+                kernel_size=3,
+                stride=stride_3x3,
+                padding=dilation,
+                bias=False,
+                groups=num_groups,
+                dilation=dilation
+            )
+            nn.init.kaiming_uniform_(self.conv2.weight, a=1)
+
+        self.bn2 = norm_func(bottleneck_channels)
+
+        self.conv3 = Conv2d(
+            bottleneck_channels, out_channels, kernel_size=1, bias=False
+        )
+        self.bn3 = norm_func(out_channels)
+
+        self.se = SELayer(out_channels) if with_se and not with_dcn else None
+
+        for l in [self.conv1, self.conv3,]:
+            nn.init.kaiming_uniform_(l.weight, a=1)
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = F.relu_(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = F.relu_(out)
+
+        out0 = self.conv3(out)
+        out = self.bn3(out0)
+
+        if self.se:
+            out = self.se(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = F.relu_(out)
+
+        return out
+
+
+class BaseStem(nn.Module):
+    def __init__(self, cfg, norm_func):
+        super(BaseStem, self).__init__()
+
+        out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
+        self.stem_3x3 = cfg.MODEL.RESNETS.USE_STEM3X3
+
+        if self.stem_3x3:
+            self.conv1 = Conv2d(
+                3, out_channels, kernel_size=3, stride=2, padding=1, bias=False
+            )
+            self.bn1 = norm_func(out_channels)
+            self.conv2 = Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False
+            )
+            self.bn2 = norm_func(out_channels)
+            for l in [self.conv1, self.conv2]:
+                nn.init.kaiming_uniform_(l.weight, a=1)
+        else:
+            self.conv1 = Conv2d(
+                3, out_channels, kernel_size=7, stride=2, padding=3, bias=False
+            )
+            self.bn1 = norm_func(out_channels)
+
+            for l in [self.conv1,]:
+                nn.init.kaiming_uniform_(l.weight, a=1)
+
+    def forward(self, x):
+        if self.stem_3x3:
+            x = self.conv1(x)
+            x = self.bn1(x)
+            x = F.relu_(x)
+            x = self.conv2(x)
+            x = self.bn2(x)
+            x = F.relu_(x)
+        else:
+            x = self.conv1(x)
+            x = self.bn1(x)
+            x = F.relu_(x)
+            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
+        return x
+
+
+class BottleneckWithFixedBatchNorm(Bottleneck):
+    def __init__(
+        self,
+        in_channels,
+        bottleneck_channels,
+        out_channels,
+        num_groups=1,
+        stride_in_1x1=True,
+        stride=1,
+        dilation=1,
+        dcn_config=None,
+        **kwargs
+    ):
+        super(BottleneckWithFixedBatchNorm, self).__init__(
+            in_channels=in_channels,
+            bottleneck_channels=bottleneck_channels,
+            out_channels=out_channels,
+            num_groups=num_groups,
+            stride_in_1x1=stride_in_1x1,
+            stride=stride,
+            dilation=dilation,
+            norm_func=FrozenBatchNorm2d,
+            dcn_config=dcn_config,
+            **kwargs
+        )
+
+
+class StemWithFixedBatchNorm(BaseStem):
+    def __init__(self, cfg):
+        super(StemWithFixedBatchNorm, self).__init__(
+            cfg, norm_func=FrozenBatchNorm2d
+        )
+
+
+class BottleneckWithBatchNorm(Bottleneck):
+    def __init__(
+        self,
+        in_channels,
+        bottleneck_channels,
+        out_channels,
+        num_groups=1,
+        stride_in_1x1=True,
+        stride=1,
+        dilation=1,
+        dcn_config=None,
+        **kwargs
+    ):
+        super(BottleneckWithBatchNorm, self).__init__(
+            in_channels=in_channels,
+            bottleneck_channels=bottleneck_channels,
+            out_channels=out_channels,
+            num_groups=num_groups,
+            stride_in_1x1=stride_in_1x1,
+            stride=stride,
+            dilation=dilation,
+            norm_func=BatchNorm2d,
+            dcn_config=dcn_config,
+            **kwargs
+        )
+
+
+class StemWithBatchNorm(BaseStem):
+    def __init__(self, cfg):
+        super(StemWithBatchNorm, self).__init__(
+            cfg, norm_func=BatchNorm2d
+        )
+
+
+class BottleneckWithNaiveSyncBatchNorm(Bottleneck):
+    def __init__(
+        self,
+        in_channels,
+        bottleneck_channels,
+        out_channels,
+        num_groups=1,
+        stride_in_1x1=True,
+        stride=1,
+        dilation=1,
+        dcn_config=None,
+        **kwargs
+    ):
+        super(BottleneckWithNaiveSyncBatchNorm, self).__init__(
+            in_channels=in_channels,
+            bottleneck_channels=bottleneck_channels,
+            out_channels=out_channels,
+            num_groups=num_groups,
+            stride_in_1x1=stride_in_1x1,
+            stride=stride,
+            dilation=dilation,
+            norm_func=NaiveSyncBatchNorm2d,
+            dcn_config=dcn_config,
+            **kwargs
+        )
+
+
+class StemWithNaiveSyncBatchNorm(BaseStem):
+    def __init__(self, cfg):
+        super(StemWithNaiveSyncBatchNorm, self).__init__(
+            cfg, norm_func=NaiveSyncBatchNorm2d
+        )
+
+
+class BottleneckWithSyncBatchNorm(Bottleneck):
+    def __init__(
+        self,
+        in_channels,
+        bottleneck_channels,
+        out_channels,
+        num_groups=1,
+        stride_in_1x1=True,
+        stride=1,
+        dilation=1,
+        dcn_config=None,
+        **kwargs
+    ):
+        super(BottleneckWithSyncBatchNorm, self).__init__(
+            in_channels=in_channels,
+            bottleneck_channels=bottleneck_channels,
+            out_channels=out_channels,
+            num_groups=num_groups,
+            stride_in_1x1=stride_in_1x1,
+            stride=stride,
+            dilation=dilation,
+            norm_func=SyncBatchNorm,
+            dcn_config=dcn_config,
+            **kwargs
+        )
+
+
+class StemWithSyncBatchNorm(BaseStem):
+    def __init__(self, cfg):
+        super(StemWithSyncBatchNorm, self).__init__(
+            cfg, norm_func=SyncBatchNorm
+        )
+
+
+class BottleneckWithGN(Bottleneck):
+    def __init__(
+        self,
+        in_channels,
+        bottleneck_channels,
+        out_channels,
+        num_groups=1,
+        stride_in_1x1=True,
+        stride=1,
+        dilation=1,
+        dcn_config=None,
+        **kwargs
+    ):
+        super(BottleneckWithGN, self).__init__(
+            in_channels=in_channels,
+            bottleneck_channels=bottleneck_channels,
+            out_channels=out_channels,
+            num_groups=num_groups,
+            stride_in_1x1=stride_in_1x1,
+            stride=stride,
+            dilation=dilation,
+            norm_func=group_norm,
+            dcn_config=dcn_config,
+            **kwargs
+        )
+
+
+class StemWithGN(BaseStem):
+    def __init__(self, cfg):
+        super(StemWithGN, self).__init__(cfg, norm_func=group_norm)
+
+
+_TRANSFORMATION_MODULES = Registry({
+    "BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm,
+    "BottleneckWithGN": BottleneckWithGN,
+})
+
+_STEM_MODULES = Registry({
+    "StemWithFixedBatchNorm": StemWithFixedBatchNorm,
+    "StemWithGN": StemWithGN,
+})
+
+_STAGE_SPECS = Registry({
+    "R-50-C4": ResNet50StagesTo4,
+    "R-50-C5": ResNet50StagesTo5,
+    "R-50-RETINANET": ResNet50StagesTo5,
+    "R-101-C4": ResNet101StagesTo4,
+    "R-101-C5": ResNet101StagesTo5,
+    "R-101-RETINANET": ResNet101StagesTo5,
+    "R-50-FPN": ResNet50FPNStagesTo5,
+    "R-50-FPN-RETINANET": ResNet50FPNStagesTo5,
+    "R-50-FPN-FCOS": ResNet50FPNStagesTo5,
+    "R-101-FPN": ResNet101FPNStagesTo5,
+    "R-101-FPN-RETINANET": ResNet101FPNStagesTo5,
+    "R-101-FPN-FCOS": ResNet101FPNStagesTo5,
+    "R-152-FPN": ResNet152FPNStagesTo5,
+})
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/swint.py b/maskrcnn_benchmark/modeling/backbone/swint.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0a162b6d28f71837a8812b3d3dfb9526451df74
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/swint.py
@@ -0,0 +1,650 @@
+# --------------------------------------------------------
+# Swin Transformer
+# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+class Mlp(nn.Module):
+    """ Multilayer perceptron."""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    """ Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """ Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class SwinTransformerBlock(nn.Module):
+    """ Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        self.H = None
+        self.W = None
+
+    def forward(self, x, mask_matrix):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    """ Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depths of this stage.
+        num_heads (int): Number of attention head.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, attn_mask)
+            else:
+                x = blk(x, attn_mask)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww
+        else:
+            return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    Args:
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, H, W = x.size()
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)  # B C Wh Ww
+        if self.norm is not None:
+            Wh, Ww = x.size(2), x.size(3)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+        return x
+
+
+class SwinTransformer(nn.Module):
+    """ Swin Transformer backbone.
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute postion embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 patch_size=4,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 frozen_stages=-1,
+                 use_checkpoint=False,
+                 out_features=["stage2", "stage3", "stage4", "stage5"],
+                 backbone_arch="SWINT-FPN-RETINANET"):
+        super(SwinTransformer, self).__init__()
+
+        print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint)
+        
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.frozen_stages = frozen_stages
+
+        self.out_features = out_features
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        self._out_feature_strides = {}
+        self._out_feature_channels = {}
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1)
+            self.layers.append(layer)
+
+            stage = f'stage{i_layer + 2}'
+            if stage in self.out_features:
+                self._out_feature_channels[stage] = embed_dim * 2 ** i_layer
+                self._out_feature_strides[stage] = 4 * 2 ** i_layer
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        for i_layer in range(self.num_layers):
+            stage = f'stage{i_layer + 2}'
+            if stage in self.out_features:
+                if i_layer == 0 and backbone_arch.endswith("RETINANET"):
+                    layer = nn.Identity()
+                else:
+                    layer = norm_layer(num_features[i_layer])
+                layer_name = f'norm{i_layer}'
+                self.add_module(layer_name, layer)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+        else:
+            x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+        outs = []
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+            name = f'stage{i + 2}'
+            if name in self.out_features:
+                norm_layer = getattr(self, f'norm{i}')
+                x_out = norm_layer(x_out)
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+        return outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+def build_swint_backbone(cfg):
+    """
+    Create a SwinT instance from config.
+
+    Returns:
+        VoVNet: a :class:`VoVNet` instance.
+    """
+    return SwinTransformer(
+        patch_size=4,
+        in_chans=3,
+        embed_dim=cfg.MODEL.SWINT.EMBED_DIM,
+        depths=cfg.MODEL.SWINT.DEPTHS,
+        num_heads=cfg.MODEL.SWINT.NUM_HEADS,
+        window_size=cfg.MODEL.SWINT.WINDOW_SIZE,
+        mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE,
+        norm_layer=nn.LayerNorm,
+        ape=cfg.MODEL.SWINT.APE,
+        patch_norm=True,
+        frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT,
+        backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY,
+        use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT,
+        out_features=cfg.MODEL.BACKBONE.OUT_FEATURES
+    )
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/swint_v2.py b/maskrcnn_benchmark/modeling/backbone/swint_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1fe550c33f06880f67176cd03c918db06c25e6f
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/swint_v2.py
@@ -0,0 +1,734 @@
+# --------------------------------------------------------
+# Swin Transformer
+# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from einops import rearrange
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+class Mlp(nn.Module):
+    """ Multilayer perceptron."""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    """ Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """ Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class SwinTransformerBlock(nn.Module):
+    """ Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        self.H = None
+        self.W = None
+
+        self.gamma = 1.0
+        if layer_scale:
+            self.gamma = nn.Parameter(
+                1e-4*torch.ones(dim), requires_grad=True
+            )
+
+
+    def forward(self, x, mask_matrix):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(self.gamma*x)
+        x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x)))
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    """ Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depths of this stage.
+        num_heads (int): Number of attention head.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False,
+                 layer_scale=False):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer,
+                layer_scale=layer_scale)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(patch_size=3, in_chans=dim, embed_dim=dim*2,
+                                         stride=2, padding=1, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, attn_mask)
+            else:
+                x = blk(x, attn_mask)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww
+        else:
+            return x, H, W, x, H, W
+
+
+# class PatchEmbed(nn.Module):
+#     """ Image to Patch Embedding
+#     Args:
+#         patch_size (int): Patch token size. Default: 4.
+#         in_chans (int): Number of input image channels. Default: 3.
+#         embed_dim (int): Number of linear projection output channels. Default: 96.
+#         norm_layer (nn.Module, optional): Normalization layer. Default: None
+#     """
+#
+#     def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+#         super().__init__()
+#         patch_size = to_2tuple(patch_size)
+#         self.patch_size = patch_size
+#
+#         self.in_chans = in_chans
+#         self.embed_dim = embed_dim
+#
+#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+#         if norm_layer is not None:
+#             self.norm = norm_layer(embed_dim)
+#         else:
+#             self.norm = None
+#
+#     def forward(self, x):
+#         """Forward function."""
+#         # padding
+#         _, _, H, W = x.size()
+#         if W % self.patch_size[1] != 0:
+#             x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+#         if H % self.patch_size[0] != 0:
+#             x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+#
+#         x = self.proj(x)  # B C Wh Ww
+#         if self.norm is not None:
+#             Wh, Ww = x.size(2), x.size(3)
+#             x = x.flatten(2).transpose(1, 2)
+#             x = self.norm(x)
+#             x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+#
+#         return x
+
+
+class ConvEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(
+        self,
+        patch_size=7,
+        in_chans=3,
+        embed_dim=64,
+        stride=4,
+        padding=2,
+        norm_layer=None
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=padding
+        )
+        self.norm = norm_layer(embed_dim) if norm_layer else None
+
+    def forward(self, x, H=None, W=None):
+        restore_hw = False
+        if H is None and W is None and len(x.size()) == 4:
+            _, _, H, W = x.size()
+            if W % self.patch_size != 0:
+                x = F.pad(x, (0, self.patch_size - W % self.patch_size))
+            if H % self.patch_size != 0:
+                x = F.pad(x, (0, 0, 0, self.patch_size - H % self.patch_size))
+            restore_hw = True
+
+        if len(x.size()) == 3:
+            x = rearrange(
+                x, 'b (h w) c -> b c h w',
+                h=H,
+                w=W
+            )
+        x = self.proj(x)  # B C Wh Ww
+        B, C, Wh, Ww = x.shape
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        if self.norm:
+            x = self.norm(x)
+
+        if restore_hw:
+            x = rearrange(
+                x, 'b (h w) c -> b c h w',
+                h=Wh,
+                w=Ww
+            )
+
+        return x
+
+
+class SwinTransformer(nn.Module):
+    """ Swin Transformer backbone.
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute postion embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 patch_size=7,
+                 patch_padding=2,
+                 patch_stride=4,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 frozen_stages=-1,
+                 use_checkpoint=False,
+                 layer_scale=False,
+                 out_features=["stage2", "stage3", "stage4", "stage5"],
+                 out_norm=True,
+                 backbone_arch="SWINT-FPN-RETINANET"):
+        super(SwinTransformer, self).__init__()
+
+        print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint)
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.frozen_stages = frozen_stages
+
+        self.out_features = out_features
+        self.out_norm = out_norm
+
+        # split image into non-overlapping patches
+        # self.patch_embed = PatchEmbed(
+        #     patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+        #     norm_layer=norm_layer if self.patch_norm else None)
+        self.patch_embed = ConvEmbed(
+            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, padding=patch_padding,
+            norm_layer=norm_layer if self.patch_norm else None
+        )
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        self._out_feature_strides = {}
+        self._out_feature_channels = {}
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1,
+                layer_scale=layer_scale)
+            self.layers.append(layer)
+
+            stage = f'stage{i_layer + 2}'
+            if stage in self.out_features:
+                self._out_feature_channels[stage] = embed_dim * 2 ** i_layer
+                self._out_feature_strides[stage] = 4 * 2 ** i_layer
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        if self.out_norm:
+            for i_layer in range(self.num_layers):
+                stage = f'stage{i_layer + 2}'
+                if stage in self.out_features:
+                    if i_layer == 0 and backbone_arch.endswith("RETINANET"):
+                        layer = nn.Identity()
+                    else:
+                        layer = norm_layer(num_features[i_layer])
+                    layer_name = f'norm{i_layer}'
+                    self.add_module(layer_name, layer)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+        else:
+            x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+        outs = []
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+            name = f'stage{i + 2}'
+            if name in self.out_features:
+                if self.out_norm:
+                    norm_layer = getattr(self, f'norm{i}')
+                    x_out = norm_layer(x_out)
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+        return outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+def build_swint_backbone(cfg):
+    """
+    Create a SwinT instance from config.
+
+    Returns:
+        VoVNet: a :class:`VoVNet` instance.
+    """
+    return SwinTransformer(
+        patch_size=7,
+        patch_padding=2,
+        patch_stride=4,
+        in_chans=3,
+        embed_dim=cfg.MODEL.SWINT.EMBED_DIM,
+        depths=cfg.MODEL.SWINT.DEPTHS,
+        num_heads=cfg.MODEL.SWINT.NUM_HEADS,
+        window_size=cfg.MODEL.SWINT.WINDOW_SIZE,
+        mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE,
+        norm_layer=nn.LayerNorm,
+        ape=cfg.MODEL.SWINT.APE,
+        patch_norm=True,
+        frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT,
+        backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY,
+        use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT,
+        layer_scale=cfg.MODEL.SWINT.LAYER_SCALE,
+        out_features=cfg.MODEL.BACKBONE.OUT_FEATURES,
+        out_norm=cfg.MODEL.SWINT.OUT_NORM,
+    )
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/swint_v2_vl.py b/maskrcnn_benchmark/modeling/backbone/swint_v2_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..008fda21f1d5c82661146e30a3ff3496579035e3
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/swint_v2_vl.py
@@ -0,0 +1,861 @@
+# --------------------------------------------------------
+# Swin Transformer
+# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from einops import rearrange
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+    """ Multilayer perceptron."""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    """ Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
+                 ntext=None, dim_text=None):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+        if ntext is not None:
+            self.qkv_text = nn.Linear(dim_text, dim * 3, bias=qkv_bias)
+            self.proj_text = nn.Linear(dim, dim_text)
+
+            self.i2t_relative_position_bias = nn.Parameter(
+                torch.zeros(2, num_heads, ntext))  # (2, nH, ntext)
+            self.t2t_relative_position_bias = nn.Parameter(
+                torch.zeros(num_heads, ntext, ntext))  # (nH, ntext, ntext)
+            trunc_normal_(self.i2t_relative_position_bias, std=.02)
+            trunc_normal_(self.t2t_relative_position_bias, std=.02)
+
+    def forward(self, x, mask=None, x_text=None, mask_text=None):
+        """ Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+            x_text: input text features with shape of (B_text, N_text, C_text)
+            mask_text: (0/-inf) mask with shape of (B_text, N_text) or None; TODO: support casual mask
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+
+        if x_text is not None:
+            B_text, N_text, C_text = x_text.shape
+            nW = B_ // B_text  # number of windows
+            assert B_text * nW == B_, "B_ is not a multiplier of B_text in window attention"
+            # notice that after qkv_text, the hidden dimension is C instead of C_text
+            qkv_text = self.qkv_text(x_text).reshape(B_text, N_text, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3,
+                                                                                                                1, 4)
+            q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[
+                2]  # make torchscript happy (cannot use tensor as tuple)
+
+            # image to text attention
+            attn_i2t = (q @ torch.repeat_interleave(k_text, nW, dim=0).transpose(-2, -1))  # B_, nH, N, N_text
+            # add image to text bias and text_mask
+            if mask_text is not None:
+                mask_and_i2t_bias = mask_text.view(B_text, 1, 1, N_text) + self.i2t_relative_position_bias[:1].expand(
+                    B_text, -1, -1).unsqueeze(-2)  # B_text, nH, 1, N_text
+            else:
+                mask_and_i2t_bias = self.i2t_relative_position_bias[:1].expand(B_text, -1, -1).unsqueeze(
+                    -2)  # B_text, nH, 1, N_text
+            attn_i2t = attn_i2t + torch.repeat_interleave(mask_and_i2t_bias, nW, dim=0)
+
+            attn = torch.cat((attn, attn_i2t), dim=-1)  # B_, nH, N, N+N_text
+
+        attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        if x_text is None:
+            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        else:
+            x = (
+                    attn @ torch.cat((v, torch.repeat_interleave(v_text, nW, dim=0)), dim=-2)
+            ).transpose(1, 2).reshape(B_, N, C)
+
+            # compute attn_t2i
+            q_text = q_text * self.scale
+
+            kv = qkv[1:].reshape(2, B_text, nW, self.num_heads, N, C // self.num_heads).transpose(2, 3)
+            k, v = kv[0].reshape(B_text, self.num_heads, nW * N, -1), kv[1].reshape(B_text, self.num_heads, nW * N, -1)
+            attn_t2i = (q_text @ k.transpose(-2, -1))
+            mask_t2i = self.i2t_relative_position_bias[1:].expand(B_text, -1, -1).unsqueeze(-1)  # B_text, nH, N_text, 1
+            attn_t2i = attn_t2i + mask_t2i
+
+            attn_t2t = (q_text @ k_text.transpose(-2, -1))
+            # add relative positional bias
+            attn_t2t = attn_t2t + self.t2t_relative_position_bias.unsqueeze(0)
+            if mask_text is not None:
+                attn_t2t = attn_t2t + mask_text.view(B_text, 1, 1, N_text)
+
+            attn_t = torch.cat((attn_t2i, attn_t2t), dim=-1)  # B_text, nH, N_text, N+N_text
+            attn_t = self.softmax(attn_t)
+            attn_t = self.attn_drop(attn_t)
+
+            x_text = (
+                    attn_t @ torch.cat((v, v_text), dim=-2)
+            ).transpose(1, 2).reshape(B_text, N_text, C)
+
+            x_text = self.proj_text(x_text)
+            x_text = self.proj_drop(x_text)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, x_text
+
+
+class SwinTransformerBlock(nn.Module):
+    """ Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False, ntext=None, dim_text=None):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
+            ntext=ntext, dim_text=dim_text
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        self.H = None
+        self.W = None
+
+        self.gamma = 1.0
+        if layer_scale:
+            self.gamma = nn.Parameter(
+                1e-4*torch.ones(dim), requires_grad=True
+            )
+
+        if dim_text is not None:
+            self.norm1_text = norm_layer(dim_text)
+            self.norm2_text = norm_layer(dim_text)
+            mlp_hidden_dim_text = int(dim_text * mlp_ratio)
+            self.mlp_text = Mlp(in_features=dim_text, hidden_features=mlp_hidden_dim_text, act_layer=act_layer,
+                                drop=drop)
+            self.gamma_text = 1.0
+            if layer_scale:
+                self.gamma_text = nn.Parameter(
+                    1e-4*torch.ones(dim_text), requires_grad=True
+                )
+
+    def forward(self, x, mask_matrix, x_text, mask_text):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+            x_text: Input text feature, tensor size (B, L_text, C_text). L_text: Number of text tokens.
+            mask_text: text mask (vector right now).
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        if x_text is not None:
+            B, L_text, C_text = x_text.shape
+            shortcut_text = x_text
+            x_text = self.norm1_text(x_text)
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows, x_text = self.attn(x_windows, mask=attn_mask, x_text=x_text,
+                                         mask_text=mask_text)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(self.gamma*x)
+        x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x)))
+
+        if x_text is not None:
+            x_text = shortcut_text + self.drop_path(self.gamma_text*x_text)
+            x_text = x_text + self.drop_path(self.gamma_text*self.mlp_text(self.norm2_text(x_text)))
+
+        return x, x_text
+
+
+class PatchMerging(nn.Module):
+    """ Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depths of this stage.
+        num_heads (int): Number of attention head.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False,
+                 layer_scale=False,
+                 ntext=None,
+                 dim_text=None):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer,
+                layer_scale=layer_scale,
+                ntext=ntext,
+                dim_text=dim_text)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(patch_size=3, in_chans=dim, embed_dim=dim*2,
+                                         stride=2, padding=1, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W, x_text=None, mask_text=None):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            x_text: input text features with shape of (B_text, N_text, C_text)
+            mask_text: (0/-inf) mask with shape of (B_text, N_text) or None;
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x, x_text = checkpoint.checkpoint(blk, x, attn_mask, x_text, mask_text)
+            else:
+                x, x_text = blk(x, attn_mask, x_text, mask_text)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww, x_text
+        else:
+            return x, H, W, x, H, W, x_text
+
+
+# class PatchEmbed(nn.Module):
+#     """ Image to Patch Embedding
+#     Args:
+#         patch_size (int): Patch token size. Default: 4.
+#         in_chans (int): Number of input image channels. Default: 3.
+#         embed_dim (int): Number of linear projection output channels. Default: 96.
+#         norm_layer (nn.Module, optional): Normalization layer. Default: None
+#     """
+#
+#     def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+#         super().__init__()
+#         patch_size = to_2tuple(patch_size)
+#         self.patch_size = patch_size
+#
+#         self.in_chans = in_chans
+#         self.embed_dim = embed_dim
+#
+#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+#         if norm_layer is not None:
+#             self.norm = norm_layer(embed_dim)
+#         else:
+#             self.norm = None
+#
+#     def forward(self, x):
+#         """Forward function."""
+#         # padding
+#         _, _, H, W = x.size()
+#         if W % self.patch_size[1] != 0:
+#             x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+#         if H % self.patch_size[0] != 0:
+#             x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+#
+#         x = self.proj(x)  # B C Wh Ww
+#         if self.norm is not None:
+#             Wh, Ww = x.size(2), x.size(3)
+#             x = x.flatten(2).transpose(1, 2)
+#             x = self.norm(x)
+#             x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+#
+#         return x
+
+
+class ConvEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(
+        self,
+        patch_size=7,
+        in_chans=3,
+        embed_dim=64,
+        stride=4,
+        padding=2,
+        norm_layer=None
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=padding
+        )
+        self.norm = norm_layer(embed_dim) if norm_layer else None
+
+    def forward(self, x, H=None, W=None):
+        restore_hw = False
+        if H is None and W is None and len(x.size()) == 4:
+            _, _, H, W = x.size()
+            if W % self.patch_size != 0:
+                x = F.pad(x, (0, self.patch_size - W % self.patch_size))
+            if H % self.patch_size != 0:
+                x = F.pad(x, (0, 0, 0, self.patch_size - H % self.patch_size))
+            restore_hw = True
+
+        if len(x.size()) == 3:
+            x = rearrange(
+                x, 'b (h w) c -> b c h w',
+                h=H,
+                w=W
+            )
+        x = self.proj(x)  # B C Wh Ww
+        B, C, Wh, Ww = x.shape
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        if self.norm:
+            x = self.norm(x)
+
+        if restore_hw:
+            x = rearrange(
+                x, 'b (h w) c -> b c h w',
+                h=Wh,
+                w=Ww
+            )
+
+        return x
+
+
+class SwinTransformer(nn.Module):
+    """ Swin Transformer backbone.
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute postion embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 patch_size=7,
+                 patch_padding=2,
+                 patch_stride=4,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 frozen_stages=-1,
+                 use_checkpoint=False,
+                 layer_scale=False,
+                 out_features=["stage2", "stage3", "stage4", "stage5"],
+                 out_norm=True,
+                 backbone_arch="SWINT-FPN-RETINANET",
+                 max_query_len=None,
+                 lang_dim=None):
+        super(SwinTransformer, self).__init__()
+
+        print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint)
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.frozen_stages = frozen_stages
+
+        self.out_features = out_features
+        self.out_norm = out_norm
+
+        # split image into non-overlapping patches
+        # self.patch_embed = PatchEmbed(
+        #     patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+        #     norm_layer=norm_layer if self.patch_norm else None)
+        self.patch_embed = ConvEmbed(
+            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, padding=patch_padding,
+            norm_layer=norm_layer if self.patch_norm else None
+        )
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        self._out_feature_strides = {}
+        self._out_feature_channels = {}
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            if i_layer < self.num_layers - 1:
+                ntext, dim_text = None, None
+            else:
+                ntext, dim_text = max_query_len, lang_dim
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1,
+                layer_scale=layer_scale,
+                ntext=ntext,
+                dim_text=dim_text
+            )
+            self.layers.append(layer)
+
+            stage = f'stage{i_layer + 2}'
+            if stage in self.out_features:
+                self._out_feature_channels[stage] = embed_dim * 2 ** i_layer
+                self._out_feature_strides[stage] = 4 * 2 ** i_layer
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        if self.out_norm:
+            for i_layer in range(self.num_layers):
+                stage = f'stage{i_layer + 2}'
+                if stage in self.out_features:
+                    if i_layer == 0 and backbone_arch.endswith("RETINANET"):
+                        layer = nn.Identity()
+                    else:
+                        layer = norm_layer(num_features[i_layer])
+                    layer_name = f'norm{i_layer}'
+                    self.add_module(layer_name, layer)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = inputs["img"]
+        language_dict_features = inputs["lang"]
+
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+        else:
+            x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+        x_text = language_dict_features['hidden']
+        if "masks" in language_dict_features:
+            mask_text = 1.0 - language_dict_features["masks"]    # (B, N_text) 0 means not to be masked out
+            mask_text.masked_fill_(mask_text.bool(), -float('inf'))
+        else:
+            mask_text = None
+
+        outs = []
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            if i < self.num_layers - 1:
+                x_out, H, W, x, Wh, Ww, _ = layer(x, Wh, Ww, x_text=None, mask_text=None)
+            else:
+                x_out, H, W, x, Wh, Ww, x_text = layer(x, Wh, Ww, x_text=x_text, mask_text=mask_text)
+            name = f'stage{i + 2}'
+            if name in self.out_features:
+                if self.out_norm:
+                    norm_layer = getattr(self, f'norm{i}')
+                    x_out = norm_layer(x_out)
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+        # the backbone only update the "hidden" field, currently
+        language_dict_features['hidden'] = x_text
+
+        return outs, language_dict_features
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+def build_swint_backbone(cfg):
+    """
+    Create a SwinT instance from config.
+
+    Returns:
+        VoVNet: a :class:`VoVNet` instance.
+    """
+    return SwinTransformer(
+        patch_size=7,
+        patch_padding=2,
+        patch_stride=4,
+        in_chans=3,
+        embed_dim=cfg.MODEL.SWINT.EMBED_DIM,
+        depths=cfg.MODEL.SWINT.DEPTHS,
+        num_heads=cfg.MODEL.SWINT.NUM_HEADS,
+        window_size=cfg.MODEL.SWINT.WINDOW_SIZE,
+        mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE,
+        norm_layer=nn.LayerNorm,
+        ape=cfg.MODEL.SWINT.APE,
+        patch_norm=True,
+        frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT,
+        backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY,
+        use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT,
+        layer_scale=cfg.MODEL.SWINT.LAYER_SCALE,
+        out_features=cfg.MODEL.BACKBONE.OUT_FEATURES,
+        out_norm=cfg.MODEL.SWINT.OUT_NORM,
+        max_query_len=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
+        lang_dim=cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
+    )
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/backbone/swint_vl.py b/maskrcnn_benchmark/modeling/backbone/swint_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..97ed5705f727c26f0a5bbb21e95050d39a5348da
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/backbone/swint_vl.py
@@ -0,0 +1,774 @@
+# --------------------------------------------------------
+# Swin Transformer
+# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+    """ Multilayer perceptron."""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    """ Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
+                 ntext=None, dim_text=None):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+        if ntext is not None:
+            self.qkv_text = nn.Linear(dim_text, dim * 3, bias=qkv_bias)
+            self.proj_text = nn.Linear(dim, dim_text)
+
+            self.i2t_relative_position_bias = nn.Parameter(
+                torch.zeros(2, num_heads, ntext))  # (2, nH, ntext)
+            self.t2t_relative_position_bias = nn.Parameter(
+                torch.zeros(num_heads, ntext, ntext))  # (nH, ntext, ntext)
+            trunc_normal_(self.i2t_relative_position_bias, std=.02)
+            trunc_normal_(self.t2t_relative_position_bias, std=.02)
+
+    def forward(self, x, mask=None, x_text=None, mask_text=None):
+        """ Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+            x_text: input text features with shape of (B_text, N_text, C_text)
+            mask_text: (0/-inf) mask with shape of (B_text, N_text) or None; TODO: support casual mask
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+
+        if x_text is not None:
+            B_text, N_text, C_text = x_text.shape
+            nW = B_ // B_text  # number of windows
+            assert B_text * nW == B_, "B_ is not a multiplier of B_text in window attention"
+            # notice that after qkv_text, the hidden dimension is C instead of C_text
+            qkv_text = self.qkv_text(x_text).reshape(B_text, N_text, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3,
+                                                                                                                1, 4)
+            q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[
+                2]  # make torchscript happy (cannot use tensor as tuple)
+
+            # image to text attention
+            attn_i2t = (q @ torch.repeat_interleave(k_text, nW, dim=0).transpose(-2, -1))  # B_, nH, N, N_text
+            # add image to text bias and text_mask
+            if mask_text is not None:
+                mask_and_i2t_bias = mask_text.view(B_text, 1, 1, N_text) + self.i2t_relative_position_bias[:1].expand(
+                    B_text, -1, -1).unsqueeze(-2)  # B_text, nH, 1, N_text
+            else:
+                mask_and_i2t_bias = self.i2t_relative_position_bias[:1].expand(B_text, -1, -1).unsqueeze(
+                    -2)  # B_text, nH, 1, N_text
+            attn_i2t = attn_i2t + torch.repeat_interleave(mask_and_i2t_bias, nW, dim=0)
+
+            attn = torch.cat((attn, attn_i2t), dim=-1)  # B_, nH, N, N+N_text
+
+        attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        if x_text is None:
+            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        else:
+            x = (
+                    attn @ torch.cat((v, torch.repeat_interleave(v_text, nW, dim=0)), dim=-2)
+            ).transpose(1, 2).reshape(B_, N, C)
+
+            # compute attn_t2i
+            q_text = q_text * self.scale
+
+            kv = qkv[1:].reshape(2, B_text, nW, self.num_heads, N, C // self.num_heads).transpose(2, 3)
+            k, v = kv[0].reshape(B_text, self.num_heads, nW * N, -1), kv[1].reshape(B_text, self.num_heads, nW * N, -1)
+            attn_t2i = (q_text @ k.transpose(-2, -1))
+            mask_t2i = self.i2t_relative_position_bias[1:].expand(B_text, -1, -1).unsqueeze(-1)  # B_text, nH, N_text, 1
+            attn_t2i = attn_t2i + mask_t2i
+
+            attn_t2t = (q_text @ k_text.transpose(-2, -1))
+            # add relative positional bias
+            attn_t2t = attn_t2t + self.t2t_relative_position_bias.unsqueeze(0)
+            if mask_text is not None:
+                attn_t2t = attn_t2t + mask_text.view(B_text, 1, 1, N_text)
+
+            attn_t = torch.cat((attn_t2i, attn_t2t), dim=-1)  # B_text, nH, N_text, N+N_text
+            attn_t = self.softmax(attn_t)
+            attn_t = self.attn_drop(attn_t)
+
+            x_text = (
+                    attn_t @ torch.cat((v, v_text), dim=-2)
+            ).transpose(1, 2).reshape(B_text, N_text, C)
+
+            x_text = self.proj_text(x_text)
+            x_text = self.proj_drop(x_text)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, x_text
+
+
+class SwinTransformerBlock(nn.Module):
+    """ Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, ntext=None, dim_text=None):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
+            ntext=ntext, dim_text=dim_text
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        self.H = None
+        self.W = None
+
+        if dim_text is not None:
+            self.norm1_text = norm_layer(dim_text)
+            self.norm2_text = norm_layer(dim_text)
+            mlp_hidden_dim_text = int(dim_text * mlp_ratio)
+            self.mlp_text = Mlp(in_features=dim_text, hidden_features=mlp_hidden_dim_text, act_layer=act_layer,
+                                drop=drop)
+
+    def forward(self, x, mask_matrix, x_text, mask_text):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+            x_text: Input text feature, tensor size (B, L_text, C_text). L_text: Number of text tokens.
+            mask_text: text mask (vector right now).
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        if x_text is not None:
+            B, L_text, C_text = x_text.shape
+            shortcut_text = x_text
+            x_text = self.norm1_text(x_text)
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows, x_text = self.attn(x_windows, mask=attn_mask, x_text=x_text,
+                                         mask_text=mask_text)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        if x_text is not None:
+            x_text = shortcut_text + self.drop_path(x_text)
+            x_text = x_text + self.drop_path(self.mlp_text(self.norm2_text(x_text)))
+
+        return x, x_text
+
+
+class PatchMerging(nn.Module):
+    """ Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depths of this stage.
+        num_heads (int): Number of attention head.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False,
+                 ntext=None,
+                 dim_text=None):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer,
+                ntext=ntext,
+                dim_text=dim_text)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W, x_text=None, mask_text=None):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            x_text: input text features with shape of (B_text, N_text, C_text)
+            mask_text: (0/-inf) mask with shape of (B_text, N_text) or None;
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x, x_text = checkpoint.checkpoint(blk, x, attn_mask, x_text, mask_text)
+            else:
+                x, x_text = blk(x, attn_mask, x_text, mask_text)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww, x_text
+        else:
+            return x, H, W, x, H, W, x_text
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    Args:
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, H, W = x.size()
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)  # B C Wh Ww
+        if self.norm is not None:
+            Wh, Ww = x.size(2), x.size(3)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+        return x
+
+
+class SwinTransformer(nn.Module):
+    """ Swin Transformer backbone.
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute postion embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 patch_size=4,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 frozen_stages=-1,
+                 use_checkpoint=False,
+                 out_features=["stage2", "stage3", "stage4", "stage5"],
+                 backbone_arch="SWINT-FPN-RETINANET",
+                 max_query_len=None,
+                 lang_dim=None):
+        super(SwinTransformer, self).__init__()
+
+        print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint)
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.frozen_stages = frozen_stages
+
+        self.out_features = out_features
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        self._out_feature_strides = {}
+        self._out_feature_channels = {}
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            if i_layer < self.num_layers - 1:
+                ntext, dim_text = None, None
+            else:
+                ntext, dim_text = max_query_len, lang_dim
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1,
+                ntext=ntext,
+                dim_text=dim_text
+            )
+            self.layers.append(layer)
+
+            stage = f'stage{i_layer + 2}'
+            if stage in self.out_features:
+                self._out_feature_channels[stage] = embed_dim * 2 ** i_layer
+                self._out_feature_strides[stage] = 4 * 2 ** i_layer
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        for i_layer in range(self.num_layers):
+            stage = f'stage{i_layer + 2}'
+            if stage in self.out_features:
+                if i_layer == 0 and backbone_arch.endswith("RETINANET"):
+                    layer = nn.Identity()
+                else:
+                    layer = norm_layer(num_features[i_layer])
+                layer_name = f'norm{i_layer}'
+                self.add_module(layer_name, layer)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = inputs["img"]
+        language_dict_features = inputs["lang"]
+
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+        else:
+            x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+        x_text = language_dict_features['hidden']
+        if "masks" in language_dict_features:
+            mask_text = 1.0 - language_dict_features["masks"]    # (B, N_text) 0 means not to be masked out
+            mask_text.masked_fill_(mask_text.bool(), -float('inf'))
+        else:
+            mask_text = None
+
+
+        outs = []
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            if i < self.num_layers - 1:
+                x_out, H, W, x, Wh, Ww, _ = layer(x, Wh, Ww, x_text=None, mask_text=None)
+            else:
+                x_out, H, W, x, Wh, Ww, x_text = layer(x, Wh, Ww, x_text=x_text, mask_text=mask_text)
+            name = f'stage{i + 2}'
+            if name in self.out_features:
+                norm_layer = getattr(self, f'norm{i}')
+                x_out = norm_layer(x_out)
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+        # the backbone only update the "hidden" field, currently
+        language_dict_features['hidden'] = x_text
+
+        return outs, language_dict_features
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+def build_swint_backbone(cfg):
+    """
+    Create a SwinT instance from config.
+
+    Returns:
+        VoVNet: a :class:`VoVNet` instance.
+    """
+    return SwinTransformer(
+        patch_size=4,
+        in_chans=3,
+        embed_dim=cfg.MODEL.SWINT.EMBED_DIM,
+        depths=cfg.MODEL.SWINT.DEPTHS,
+        num_heads=cfg.MODEL.SWINT.NUM_HEADS,
+        window_size=cfg.MODEL.SWINT.WINDOW_SIZE,
+        mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE,
+        norm_layer=nn.LayerNorm,
+        ape=cfg.MODEL.SWINT.APE,
+        patch_norm=True,
+        frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT,
+        backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY,
+        use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT,
+        out_features=cfg.MODEL.BACKBONE.OUT_FEATURES,
+        max_query_len=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
+        lang_dim=cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
+    )
diff --git a/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..388c7ea8720a77bdc93718754798fcdeb43f6383
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py
@@ -0,0 +1,68 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+
+class BalancedPositiveNegativeSampler(object):
+    """
+    This class samples batches, ensuring that they contain a fixed proportion of positives
+    """
+
+    def __init__(self, batch_size_per_image, positive_fraction):
+        """
+        Arguments:
+            batch_size_per_image (int): number of elements to be selected per image
+            positive_fraction (float): percentace of positive elements per batch
+        """
+        self.batch_size_per_image = batch_size_per_image
+        self.positive_fraction = positive_fraction
+
+    def __call__(self, matched_idxs):
+        """
+        Arguments:
+            matched idxs: list of tensors containing -1, 0 or positive values.
+                Each tensor corresponds to a specific image.
+                -1 values are ignored, 0 are considered as negatives and > 0 as
+                positives.
+
+        Returns:
+            pos_idx (list[tensor])
+            neg_idx (list[tensor])
+
+        Returns two lists of binary masks for each image.
+        The first list contains the positive elements that were selected,
+        and the second list the negative example.
+        """
+        pos_idx = []
+        neg_idx = []
+        for matched_idxs_per_image in matched_idxs:
+            positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
+            negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
+
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            num_pos = min(positive.numel(), num_pos)
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            num_neg = min(negative.numel(), num_neg)
+
+            # randomly select positive and negative examples
+            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+            pos_idx_per_image = positive[perm1]
+            neg_idx_per_image = negative[perm2]
+
+            # create binary mask from indices
+            pos_idx_per_image_mask = torch.zeros_like(
+                matched_idxs_per_image, dtype=torch.bool
+            )
+            neg_idx_per_image_mask = torch.zeros_like(
+                matched_idxs_per_image, dtype=torch.bool
+            )
+            pos_idx_per_image_mask[pos_idx_per_image] = 1
+            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx.append(pos_idx_per_image_mask)
+            neg_idx.append(neg_idx_per_image_mask)
+
+        return pos_idx, neg_idx
diff --git a/maskrcnn_benchmark/modeling/box_coder.py b/maskrcnn_benchmark/modeling/box_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ca39db1aa954be3482259797706ca12e56a77f1
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/box_coder.py
@@ -0,0 +1,95 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import math
+
+import torch
+
+
+class BoxCoder(object):
+    """
+    This class encodes and decodes a set of bounding boxes into
+    the representation used for training the regressors.
+    """
+
+    def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
+        """
+        Arguments:
+            weights (4-element tuple)
+            bbox_xform_clip (float)
+        """
+        self.weights = weights
+        self.bbox_xform_clip = bbox_xform_clip
+
+    def encode(self, reference_boxes, proposals):
+        """
+        Encode a set of proposals with respect to some
+        reference boxes
+
+        Arguments:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+        """
+
+        TO_REMOVE = 1  # TODO remove
+        ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE
+        ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE
+        ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
+        ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights
+
+        gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE
+        gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE
+        gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
+        gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights
+
+        wx, wy, ww, wh = self.weights
+        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+        targets_dw = ww * torch.log(gt_widths / ex_widths)
+        targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+        targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+        return targets
+
+    def decode(self, rel_codes, boxes):
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Arguments:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        TO_REMOVE = 1  # TODO remove
+        widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
+        heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = rel_codes[:, 0::4] / wx
+        dy = rel_codes[:, 1::4] / wy
+        dw = rel_codes[:, 2::4] / ww
+        dh = rel_codes[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.bbox_xform_clip)
+        dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        pred_boxes = torch.zeros_like(rel_codes)
+        # x1
+        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
+        # y1
+        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
+        # x2 (note: "- 1" is correct; don't be fooled by the asymmetry)
+        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1
+        # y2 (note: "- 1" is correct; don't be fooled by the asymmetry)
+        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1
+
+        return pred_boxes
diff --git a/maskrcnn_benchmark/modeling/detector/__init__.py b/maskrcnn_benchmark/modeling/detector/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf35abda3c725249b875078a85d76ede727f9d19
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/detector/__init__.py
@@ -0,0 +1,11 @@
+from .generalized_rcnn import GeneralizedRCNN
+from .generalized_vl_rcnn import GeneralizedVLRCNN
+
+_DETECTION_META_ARCHITECTURES = {"GeneralizedRCNN": GeneralizedRCNN,
+                                 "GeneralizedVLRCNN": GeneralizedVLRCNN
+                                 }
+
+
+def build_detection_model(cfg):
+    meta_arch = _DETECTION_META_ARCHITECTURES[cfg.MODEL.META_ARCHITECTURE]
+    return meta_arch(cfg)
diff --git a/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py b/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7307722d5e281e52a73bff6fb76706445c11e810
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py
@@ -0,0 +1,124 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+Implements the Generalized R-CNN framework
+"""
+
+import torch
+from torch import nn
+
+from maskrcnn_benchmark.structures.image_list import to_image_list
+
+from ..backbone import build_backbone
+from ..rpn import build_rpn
+from ..roi_heads import build_roi_heads
+
+import timeit
+
+class GeneralizedRCNN(nn.Module):
+    """
+    Main class for Generalized R-CNN. Currently supports boxes and masks.
+    It consists of three main parts:
+    - backbone
+    - rpn
+    - heads: takes the features + the proposals from the RPN and computes
+        detections / masks from it.
+    """
+
+    def __init__(self, cfg):
+        super(GeneralizedRCNN, self).__init__()
+
+        self.backbone = build_backbone(cfg)
+        self.rpn = build_rpn(cfg)
+        self.roi_heads = build_roi_heads(cfg)
+        self.DEBUG = cfg.MODEL.DEBUG
+        self.ONNX = cfg.MODEL.ONNX
+        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
+        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
+        self.freeze_rpn = cfg.MODEL.RPN.FREEZE
+
+        if cfg.MODEL.LINEAR_PROB:
+            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
+            if hasattr(self.backbone, 'fpn'):
+                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
+        self.linear_prob = cfg.MODEL.LINEAR_PROB
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(GeneralizedRCNN, self).train(mode)
+        if self.freeze_backbone:
+            self.backbone.body.eval()
+            for p in self.backbone.body.parameters():
+                p.requires_grad = False
+        if self.freeze_fpn:
+            self.backbone.fpn.eval()
+            for p in self.backbone.fpn.parameters():
+                p.requires_grad = False
+        if self.freeze_rpn:
+            self.rpn.eval()
+            for p in self.rpn.parameters():
+                p.requires_grad = False
+        if self.linear_prob:
+            if self.rpn is not None:
+                for key, value in self.rpn.named_parameters():
+                    if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key):
+                        value.requires_grad = False
+            if self.roi_heads is not None:
+                for key, value in self.roi_heads.named_parameters():
+                    if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key):
+                        value.requires_grad = False
+
+    def forward(self, images, targets=None):
+        """
+        Arguments:
+            images (list[Tensor] or ImageList): images to be processed
+            targets (list[BoxList]): ground-truth boxes present in the image (optional)
+
+        Returns:
+            result (list[BoxList] or dict[Tensor]): the output from the model.
+                During training, it returns a dict[Tensor] which contains the losses.
+                During testing, it returns list[BoxList] contains additional fields
+                like `scores`, `labels` and `mask` (for Mask R-CNN models).
+
+        """
+        if self.training and targets is None:
+            raise ValueError("In training mode, targets should be passed")
+
+        if self.DEBUG: debug_info = {}
+        if self.DEBUG: debug_info['input_size'] = images[0].size()
+        if self.DEBUG: tic = timeit.time.perf_counter()
+
+        if self.ONNX:
+            features = self.backbone(images)
+        else:
+            images = to_image_list(images)
+            features = self.backbone(images.tensors)
+
+        if self.DEBUG: debug_info['feat_time'] = timeit.time.perf_counter() - tic
+        if self.DEBUG: debug_info['feat_size'] = [feat.size() for feat in features]
+        if self.DEBUG: tic = timeit.time.perf_counter()
+
+        proposals, proposal_losses = self.rpn(images, features, targets)
+
+        if self.DEBUG: debug_info['rpn_time'] = timeit.time.perf_counter() - tic
+        if self.DEBUG: debug_info['#rpn'] = [prop for prop in proposals]
+        if self.DEBUG: tic = timeit.time.perf_counter()
+
+        if self.roi_heads:
+            x, result, detector_losses = self.roi_heads(features, proposals, targets)
+        else:
+            # RPN-only models don't have roi_heads
+            x = features
+            result = proposals
+            detector_losses = {}
+
+        if self.DEBUG: debug_info['rcnn_time'] = timeit.time.perf_counter() - tic
+        if self.DEBUG: debug_info['#rcnn'] = result
+        if self.DEBUG: return result, debug_info
+
+        if self.training:
+            losses = {}
+            losses.update(detector_losses)
+            losses.update(proposal_losses)
+            return losses
+
+        return result
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/detector/generalized_vl_rcnn.py b/maskrcnn_benchmark/modeling/detector/generalized_vl_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..01f64c6ffc6272a222777cb4deb6b2ee3d715b23
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/detector/generalized_vl_rcnn.py
@@ -0,0 +1,466 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+Implements the Generalized VL R-CNN framework
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from maskrcnn_benchmark.structures.image_list import to_image_list
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+
+from ..backbone import build_backbone
+from ..rpn import build_rpn
+from ..roi_heads import build_roi_heads
+
+from ..language_backbone import build_language_backbone
+from transformers import AutoTokenizer
+
+import random
+import timeit
+import pdb
+from copy import deepcopy
+
+def random_word(input_ids, mask_token_id, vocabs, padding_token_id, greenlight_map):
+    """
+    greenlight_map, batch_size x 256 (seq_len):
+        0 means this location cannot be calculated in the MLM loss
+        -1 means this location cannot be masked!!
+        1 means this location can be masked and can be calculated in the MLM loss
+    """
+    output_label = deepcopy(input_ids)
+    for j in range(input_ids.size(0)):
+        for i in range(input_ids.size(1)):
+            prob = random.random()
+            # mask token with probability
+            ratio = 0.15
+            if greenlight_map is not None and greenlight_map[j,i] == -1:
+                output_label[j,i] = -100
+                continue
+
+            if (not input_ids[j,i] == padding_token_id) and prob < ratio:
+                prob /= ratio
+
+                # 80% randomly change token to mask token
+                if prob < 0.8:
+                    input_ids[j,i] = mask_token_id
+
+                # 10% randomly change token to random token
+                elif prob < 0.9:
+                    input_ids[j,i] = random.choice(vocabs)
+
+            else:
+                # no masking token (will be ignored by loss function later)
+                output_label[j,i] = -100
+            
+            if greenlight_map is not None and greenlight_map[j,i] != 1:
+                output_label[j,i] = -100 # If this location should not be masked
+    return input_ids, output_label
+
+
+class GeneralizedVLRCNN(nn.Module):
+    """
+    Main class for Generalized R-CNN. Currently supports boxes and masks.
+    It consists of three main parts:
+    - backbone
+    - rpn
+    - heads: takes the features + the proposals from the RPN and computes
+        detections / masks from it.
+    """
+
+    def __init__(self, cfg):
+        super(GeneralizedVLRCNN, self).__init__()
+        self.cfg = cfg
+
+        # visual encoder
+        self.backbone = build_backbone(cfg)
+
+        # language encoder
+        if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
+            # self.tokenizer = build_tokenizer("clip")
+            from transformers import CLIPTokenizerFast
+            if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
+                print("Reuse token 'ðŁĴij</w>' (token_id = 49404) for mask token!")
+                self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                                            from_slow=True, mask_token='ðŁĴij</w>')
+            else:
+                self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                                            from_slow=True)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
+        self.tokenizer_vocab = self.tokenizer.get_vocab()
+        self.tokenizer_vocab_ids = [item for key, item in self.tokenizer_vocab.items()]
+
+        self.language_backbone = build_language_backbone(cfg)
+
+        self.rpn = build_rpn(cfg)
+        self.roi_heads = build_roi_heads(cfg)
+        self.DEBUG = cfg.MODEL.DEBUG
+
+        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
+        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
+        self.freeze_rpn = cfg.MODEL.RPN.FREEZE
+        self.add_linear_layer = cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER
+
+        self.force_boxes = cfg.MODEL.RPN.FORCE_BOXES
+
+        if cfg.MODEL.LINEAR_PROB:
+            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
+            if hasattr(self.backbone, 'fpn'):
+                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
+        self.linear_prob = cfg.MODEL.LINEAR_PROB
+        self.freeze_cls_logits = cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS
+        if cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            # disable cls_logits
+            if hasattr(self.rpn.head, 'cls_logits'):
+                for p in self.rpn.head.cls_logits.parameters():
+                    p.requires_grad = False
+
+        self.freeze_language_backbone = self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE
+        if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
+            for p in self.language_backbone.parameters():
+                p.requires_grad = False
+        
+        self.use_mlm_loss = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS 
+        self.mlm_loss_for_only_positives = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_FOR_ONLY_POSITIVES
+
+        if self.cfg.GLIPKNOW.KNOWLEDGE_FILE:
+            from maskrcnn_benchmark.data.datasets.tsv import load_from_yaml_file
+            self.class_name_to_knowledge = load_from_yaml_file(self.cfg.GLIPKNOW.KNOWLEDGE_FILE)
+            self.class_name_list = sorted([k for k in self.class_name_to_knowledge])
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(GeneralizedVLRCNN, self).train(mode)
+        if self.freeze_backbone:
+            self.backbone.body.eval()
+            for p in self.backbone.body.parameters():
+                p.requires_grad = False
+        if self.freeze_fpn:
+            self.backbone.fpn.eval()
+            for p in self.backbone.fpn.parameters():
+                p.requires_grad = False
+        if self.freeze_rpn:
+            if hasattr(self.rpn, 'head'):
+                self.rpn.head.eval()
+            for p in self.rpn.parameters():
+                p.requires_grad = False
+        if self.linear_prob:
+            if self.rpn is not None:
+                for key, value in self.rpn.named_parameters():
+                    if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key):
+                        value.requires_grad = False
+            if self.roi_heads is not None:
+                for key, value in self.roi_heads.named_parameters():
+                    if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key):
+                        value.requires_grad = False
+        if self.freeze_cls_logits:
+            if hasattr(self.rpn.head, 'cls_logits'):
+                self.rpn.head.cls_logits.eval()
+                for p in self.rpn.head.cls_logits.parameters():
+                    p.requires_grad = False
+        if self.add_linear_layer:
+            if self.rpn is not None:
+                for key, p in self.rpn.named_parameters():
+                    if 'tunable_linear' in key:
+                        p.requires_grad = True
+
+        if self.freeze_language_backbone:
+            self.language_backbone.eval()
+            for p in self.language_backbone.parameters():
+                p.requires_grad = False
+
+    def forward(self, 
+        images, 
+        targets=None, 
+        captions=None, 
+        positive_map=None,
+        greenlight_map=None):
+        """
+        Arguments:
+            images (list[Tensor] or ImageList): images to be processed
+            targets (list[BoxList]): ground-truth boxes present in the image (optional)
+
+            mask_black_list: batch x 256, indicates whether or not a certain token is maskable or not
+
+        Returns:
+            result (list[BoxList] or dict[Tensor]): the output from the model.
+                During training, it returns a dict[Tensor] which contains the losses.
+                During testing, it returns list[BoxList] contains additional fields
+                like `scores`, `labels` and `mask` (for Mask R-CNN models).
+
+        """
+        if self.training and targets is None:
+            raise ValueError("In training mode, targets should be passed")
+        
+        images = to_image_list(images)
+        # batch_size = images.tensors.shape[0]
+        device = images.tensors.device
+
+
+        if self.cfg.GLIPKNOW.PARALLEL_LANGUAGE_INPUT:
+            language_dict_features, positive_map = self._forward_language_parallel(
+                    captions=captions, targets=targets, device=device,
+                    positive_map=positive_map)
+        else:
+            # language embedding
+            language_dict_features = {}
+            if captions is not None:
+                #print(captions[0])
+                tokenized = self.tokenizer.batch_encode_plus(captions,
+                                                            max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
+                                                            padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest",
+                                                            return_special_tokens_mask=True,
+                                                            return_tensors='pt',
+                                                            truncation=True).to(device)
+                if self.use_mlm_loss:
+                    if not self.mlm_loss_for_only_positives:
+                        greenlight_map = None
+                    input_ids, mlm_labels = random_word(
+                        input_ids=tokenized.input_ids, 
+                        mask_token_id=self.tokenizer.mask_token_id,
+                        vocabs=self.tokenizer_vocab_ids,
+                        padding_token_id=self.tokenizer.pad_token_id,
+                        greenlight_map=greenlight_map)
+                else:
+                    input_ids = tokenized.input_ids
+                    mlm_labels = None
+                
+                
+                tokenizer_input = {"input_ids": input_ids,
+                                "attention_mask": tokenized.attention_mask}
+
+                if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
+                    with torch.no_grad():
+                        language_dict_features = self.language_backbone(tokenizer_input)
+                else:
+                    language_dict_features = self.language_backbone(tokenizer_input)
+                
+                # ONE HOT
+                if self.cfg.DATASETS.ONE_HOT:
+                    new_masks = torch.zeros_like(language_dict_features['masks'],
+                                                device=language_dict_features['masks'].device)
+                    new_masks[:, :self.cfg.MODEL.DYHEAD.NUM_CLASSES] = 1
+                    language_dict_features['masks'] = new_masks
+
+                # MASK ALL SPECIAL TOKENS
+                if self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL:
+                    language_dict_features["masks"] = 1 - tokenized.special_tokens_mask
+                
+                language_dict_features["mlm_labels"] = mlm_labels
+
+        # visual embedding
+        swint_feature_c4 = None
+        if 'vl' in self.cfg.MODEL.SWINT.VERSION:
+            # the backbone only updates the "hidden" field in language_dict_features
+            inputs = {"img": images.tensors, "lang": language_dict_features}
+            visual_features, language_dict_features, swint_feature_c4 = self.backbone(inputs)
+        else:
+            visual_features = self.backbone(images.tensors)
+
+        # rpn force boxes
+        if targets:
+            targets = [target.to(device)
+                       for target in targets if target is not None]
+
+        if self.force_boxes:
+            proposals = []
+            for t in targets:
+                tb = t.copy_with_fields(["labels"])
+                tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device))
+                proposals.append(tb)
+            if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
+                _, proposal_losses, fused_visual_features = self.rpn(
+                    images, visual_features, targets, language_dict_features,
+                    positive_map, captions, swint_feature_c4)
+            elif self.training:
+                null_loss = 0
+                for key, param in self.rpn.named_parameters():
+                    null_loss += 0.0 * param.sum()
+                proposal_losses = {('rpn_null_loss', null_loss)}
+        else:
+            proposals, proposal_losses, fused_visual_features = self.rpn(images, visual_features, targets, language_dict_features, positive_map,
+                                              captions, swint_feature_c4)
+        if self.roi_heads:
+            if self.cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"):
+                if self.training:
+                    # "Only support VL mask head right now!!"
+                    assert len(targets) == 1 and len(targets[0]) == len(positive_map), "shape match assert for mask head!!"
+                    # Not necessary but as a safe guard:
+                    # use the binary 0/1 positive map to replace the normalized positive map
+                    targets[0].add_field("positive_map", positive_map)
+            # TODO: make sure that this use of language_dict_features is correct!! Its content should be changed in self.rpn
+            if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
+                x, result, detector_losses = self.roi_heads(
+                    fused_visual_features, proposals, targets,
+                    language_dict_features=language_dict_features,
+                    positive_map_label_to_token=positive_map if not self.training else None
+                )
+            else:
+                x, result, detector_losses = self.roi_heads(
+                    visual_features, proposals, targets,
+                    language_dict_features=language_dict_features,
+                    positive_map_label_to_token=positive_map if not self.training else None
+                )
+        else:
+            # RPN-only models don't have roi_heads
+            x = visual_features
+            result = proposals
+            detector_losses = {}
+
+        if self.training:
+            losses = {}
+            losses.update(detector_losses)
+            losses.update(proposal_losses)
+            return losses
+
+        return result
+
+    def _forward_language_parallel(self, captions=None, targets=None,
+            device=None, positive_map=None):
+        ktype = self.cfg.GLIPKNOW.KNOWLEDGE_TYPE
+        def _construct_captions_from_class_names(class_names):
+            captions = []
+            for c in class_names:
+                try:
+                    info = self.class_name_to_knowledge[c]
+                    cap = info['clean_name']
+
+                    # combine wiki and gpt3 knowledge
+                    if self.cfg.GLIPKNOW.WIKI_AND_GPT3:
+                        ktype = 'def_wiki'
+                        know_seq = info[ktype]
+
+                        ktype = 'gpt3'
+                        if ktype == 'gpt3' or type(info[ktype]) == list:
+                            know_seq += ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ])
+
+                        cap += ': ' + know_seq
+
+                    # only one knoweldge source is used        
+                    else:
+                        if ktype and ktype in info and info[ktype]:
+                            if ktype == 'gpt3' or type(info[ktype]) == list:
+                                know_seq = ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ])
+                            else: 
+                                know_seq = info[ktype]
+                            cap += ': ' + know_seq
+                except:
+                    cap = c
+                    print(f'cap {cap}, c {c}')
+                    
+                    
+                captions.append(cap)
+            return captions
+
+        if self.training:
+            assert captions is None
+            assert targets is not None
+
+            max_classes_per_batch = self.cfg.GLIPKNOW.MAX_NUM_CLASSES_PER_BATCH_TRAIN
+            if max_classes_per_batch >= len(self.class_name_list):
+                shuffled_class_names = self.class_name_list.copy()
+                random.shuffle(shuffled_class_names)
+                if max_classes_per_batch > len(shuffled_class_names):
+                    shuffled_class_names.extend(shuffled_class_names[:max_classes_per_batch
+                        -len(shuffled_class_names)])
+                    random.shuffle(shuffled_class_names)
+            else:
+                label_list = []
+                label_to_idx = {}
+                for target_per_im in targets:
+                    labels_per_im = target_per_im.get_field('label_names')
+                    for label in labels_per_im:
+                        if label not in label_to_idx:
+                            label_to_idx[label] = len(label_list)
+                            label_list.append(label)
+
+                label_list = label_list[:max_classes_per_batch]
+                if len(label_list) < max_classes_per_batch:
+                    all_neg_classes = [c for c in self.class_name_list if c not
+                            in label_to_idx]
+                    neg_label_list = random.sample(all_neg_classes,
+                            max_classes_per_batch - len(label_list))
+                    label_list.extend(neg_label_list)
+                random.shuffle(label_list)
+                shuffled_class_names = label_list
+
+            label_to_shuffled_idx = {l: i for i, l in
+                    enumerate(shuffled_class_names)}
+            total_boxes = sum(len(t) for t in targets)
+            positive_map = torch.zeros((total_boxes, max_classes_per_batch+1),
+                device=device)
+            offset = 0
+            for target_per_im in targets:
+                labels_per_im = target_per_im.get_field('label_names')
+                for label in labels_per_im:
+                    j = label_to_shuffled_idx.get(label, -1)
+                    if j >= 0:
+                        positive_map[offset, j] = 1
+                    offset += 1
+            captions = _construct_captions_from_class_names(shuffled_class_names)
+            captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719
+            batch_size = len(targets)
+
+        else:
+            assert captions is not None
+            batch_size = 1
+            assert len(captions) == 1
+            class_names = captions[0]
+            max_classes_per_batch = len(class_names)
+            captions = _construct_captions_from_class_names(class_names)
+            captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719
+
+        tokenized = self.tokenizer.batch_encode_plus(captions,
+                                                     max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
+                                                     padding="longest",
+                                                     return_special_tokens_mask=True,
+                                                     return_tensors='pt',
+                                                     truncation=True).to(device)
+        assert not self.use_mlm_loss
+        tokenizer_input = {"input_ids": tokenized.input_ids,
+                           "attention_mask": tokenized.attention_mask}
+
+        if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
+            with torch.no_grad():
+                language_dict_features = self.language_backbone(tokenizer_input)
+        else:
+            language_dict_features = self.language_backbone(tokenizer_input)
+
+        assert not self.cfg.DATASETS.ONE_HOT
+        assert not self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL
+
+        agg_type = self.cfg.GLIPKNOW.LAN_FEATURE_AGG_TYPE
+        agg_feats = language_dict_features['hidden']
+        agg_emb = language_dict_features['embedded']
+        if agg_type == 'first':
+            agg_feats = agg_feats[:, 0, :]
+            agg_emb = agg_emb[:, 0, :]
+        elif agg_type == 'mean':
+            attn_mask = language_dict_features['masks']
+            seq_len = attn_mask.sum(-1).unsqueeze(-1).float()
+            agg_feats = agg_feats * attn_mask.unsqueeze(-1).float()
+            agg_feats = agg_feats.sum(1) / seq_len
+            agg_emb = agg_emb * attn_mask.unsqueeze(-1).float()
+            agg_emb = agg_emb.sum(1) / seq_len
+        else:
+            raise ValueError('not supported GLIPKNOW.LAN_FEATURE_AGG_TYPE: {}'.format(agg_type))
+
+        expanded_features = agg_feats.unsqueeze(0).repeat(batch_size, 1, 1)
+        expanded_embedding = agg_emb.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        lang_dict = {}
+        lang_dict["mlm_labels"] = None
+        lang_dict["aggregate"] = None
+        lang_dict["embedded"] = expanded_embedding
+        lang_dict['hidden'] = expanded_features
+        lang_dict["masks"] = torch.ones((batch_size, max_classes_per_batch+1),
+                device=device, dtype=language_dict_features['masks'].dtype)
+        # in GLIP setting, the token at the end of seqence is usually [PAD], and is masked out
+        # if [noobj] is not masked out, the loss sum is very big, as most
+        # anchors are matched to [noobj]
+        lang_dict["masks"][:,-1] = 0
+        return lang_dict, positive_map
+
diff --git a/maskrcnn_benchmark/modeling/language_backbone/__init__.py b/maskrcnn_benchmark/modeling/language_backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78d6ab1d5b2d59007bb4c042d0fc1a5a06253da
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/__init__.py
@@ -0,0 +1,6 @@
+from .backbone import build_backbone as build_language_backbone
+from .build import build_tokenizer
+
+from .hfpt_tokenizer import HFPTTokenizer
+from .simple_tokenizer import SimpleTokenizer
+from .clip_model import CLIPTransformer
diff --git a/maskrcnn_benchmark/modeling/language_backbone/backbone.py b/maskrcnn_benchmark/modeling/language_backbone/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..632b622092b52297c690cd9c0cebcef48b842e48
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/backbone.py
@@ -0,0 +1,45 @@
+from collections import OrderedDict
+import torch
+from torch import nn
+
+from maskrcnn_benchmark.modeling import registry
+from . import bert_model
+from . import rnn_model
+from . import clip_model
+from . import word_utils
+
+
+@registry.LANGUAGE_BACKBONES.register("bert-base-uncased")
+def build_bert_backbone(cfg):
+    body = bert_model.BertEncoder(cfg)
+    model = nn.Sequential(OrderedDict([("body", body)]))
+    return model
+
+
+@registry.LANGUAGE_BACKBONES.register("roberta-base")
+def build_bert_backbone(cfg):
+    body = bert_model.BertEncoder(cfg)
+    model = nn.Sequential(OrderedDict([("body", body)]))
+    return model
+
+
+@registry.LANGUAGE_BACKBONES.register("rnn")
+def build_rnn_backbone(cfg):
+    body = rnn_model.RNNEnoder(cfg)
+    model = nn.Sequential(OrderedDict([("body", body)]))
+    return model
+
+
+@registry.LANGUAGE_BACKBONES.register("clip")
+def build_clip_backbone(cfg):
+    body = clip_model.CLIPTransformer(cfg)
+    model = nn.Sequential(OrderedDict([("body", body)]))
+    return model
+
+
+def build_backbone(cfg):
+    assert cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE in registry.LANGUAGE_BACKBONES, \
+        "cfg.MODEL.LANGUAGE_BACKBONE.TYPE: {} is not registered in registry".format(
+            cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
+        )
+    return registry.LANGUAGE_BACKBONES[cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE](cfg)
diff --git a/maskrcnn_benchmark/modeling/language_backbone/bert_model.py b/maskrcnn_benchmark/modeling/language_backbone/bert_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b69c54fc06ef600351da4addae354d971afb0e1
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/bert_model.py
@@ -0,0 +1,79 @@
+from copy import deepcopy
+import numpy as np
+import torch
+from torch import nn
+
+# from pytorch_pretrained_bert.modeling import BertModel
+from transformers import BertConfig, RobertaConfig, RobertaModel, BertModel
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, cfg):
+        super(BertEncoder, self).__init__()
+        self.cfg = cfg
+        self.bert_name = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
+        print("LANGUAGE BACKBONE USE GRADIENT CHECKPOINTING: ", self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT)
+
+        if self.bert_name == "bert-base-uncased":
+            config = BertConfig.from_pretrained(self.bert_name)
+            config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT
+            self.model = BertModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config)
+            self.language_dim = 768
+        elif self.bert_name == "roberta-base":
+            config = RobertaConfig.from_pretrained(self.bert_name)
+            config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT
+            self.model = RobertaModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config)
+            self.language_dim = 768
+        else:
+            raise NotImplementedError
+
+        self.num_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS
+
+    def forward(self, x):
+        input = x["input_ids"]
+        mask = x["attention_mask"]
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            # with padding, always 256
+            outputs = self.model(
+                input_ids=input,
+                attention_mask=mask,
+                output_hidden_states=True,
+            )
+            # outputs has 13 layers, 1 input layer and 12 hidden layers
+            encoded_layers = outputs.hidden_states[1:]
+            features = None
+            features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1)
+
+            # language embedding has shape [len(phrase), seq_len, language_dim]
+            features = features / self.num_layers
+
+            embedded = features * mask.unsqueeze(-1).float()
+            aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float())
+
+        else:
+            # without padding, only consider positive_tokens
+            max_len = (input != 0).sum(1).max().item()
+            outputs = self.model(
+                input_ids=input[:, :max_len],
+                attention_mask=mask[:, :max_len],
+                output_hidden_states=True,
+            )
+            # outputs has 13 layers, 1 input layer and 12 hidden layers
+            encoded_layers = outputs.hidden_states[1:]
+
+            features = None
+            features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1)
+            # language embedding has shape [len(phrase), seq_len, language_dim]
+            features = features / self.num_layers
+
+            embedded = features * mask[:, :max_len].unsqueeze(-1).float()
+            aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float())
+
+        ret = {
+            "aggregate": aggregate,
+            "embedded": embedded,
+            "masks": mask,
+            "hidden": encoded_layers[-1]
+        }
+        return ret
diff --git a/maskrcnn_benchmark/modeling/language_backbone/bpe_simple_vocab_16e6.txt.gz b/maskrcnn_benchmark/modeling/language_backbone/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..e74ad860329b14ff6b53f3ae0b007bec308cc5af
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc496842c2d4b6e40b2bd1207a5ded6e425e6a7cf9c16afa86caa5d7d12df233
+size 1355337
diff --git a/maskrcnn_benchmark/modeling/language_backbone/build.py b/maskrcnn_benchmark/modeling/language_backbone/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5fc534df7864869d89734b7ca48ba6d56fe5a58
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/build.py
@@ -0,0 +1,18 @@
+from .simple_tokenizer import SimpleTokenizer
+
+
+def build_tokenizer(tokenizer_name):
+    tokenizer = None
+    if tokenizer_name == 'clip':
+        tokenizer = SimpleTokenizer()
+    elif 'hf_' in tokenizer_name:
+        from .hfpt_tokenizer import HFPTTokenizer
+
+        tokenizer = HFPTTokenizer(pt_name=tokenizer_name[3:])
+    elif 'hfc_' in tokenizer_name:
+        from .hfpt_tokenizer import HFPTTokenizer
+        tokenizer = HFPTTokenizer(pt_name=tokenizer_name[4:])
+    else:
+        raise ValueError('Unknown tokenizer')
+
+    return tokenizer
diff --git a/maskrcnn_benchmark/modeling/language_backbone/clip_model.py b/maskrcnn_benchmark/modeling/language_backbone/clip_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..781f4f4ac5dabd7d232741fe88d40785ee2c1919
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/clip_model.py
@@ -0,0 +1,200 @@
+from collections import OrderedDict
+import logging
+import os
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from maskrcnn_benchmark.config import try_to_find
+
+from timm.models.layers import DropPath, trunc_normal_
+
+logger = logging.getLogger(__name__)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-12):
+        """Construct a layernorm module in the TF style (epsilon inside the square root).
+        """
+        super(LayerNorm, self).__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.bias = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x):
+        pdtype = x.dtype
+        x = x.float()
+        u = x.mean(-1, keepdim=True)
+        s = (x - u).pow(2).mean(-1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+        return self.weight * x.to(pdtype) + self.bias
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self,
+                 d_model: int,
+                 n_head: int,
+                 attn_mask: torch.Tensor = None,
+                 drop_path: float = 0.0):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
+            if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]
+
+    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
+        x = x + self.drop_path(self.mlp(self.ln_2(x)))
+        return x
+
+
+class CLIPTransformer(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.cfg = cfg
+        self.use_checkpoint = cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT
+        print("LANGUAGE BACKBONE USE GRADIENT CHECKPOINTING: ", self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT)
+
+        self.context_length = self.cfg.MODEL.CLIP.CONTEXT_LENGTH
+        self.width = self.cfg.MODEL.CLIP.WIDTH
+        self.layers = self.cfg.MODEL.CLIP.LAYERS
+        self.heads = self.cfg.MODEL.CLIP.HEADS
+        self.drop_path = self.cfg.MODEL.CLIP.DROP_PATH
+        self.vocab_size = self.cfg.MODEL.CLIP.VOCAB_SIZE
+
+        self.token_embedding = nn.Embedding(self.vocab_size, self.width)
+
+        self.positional_embedding = nn.Parameter(
+            torch.empty(self.context_length, self.width)
+        )
+
+        # attn_mask = self.build_attention_mask()
+        attn_mask = None
+
+        dpr = [x.item() for x in torch.linspace(0, self.drop_path, self.layers)]  # stochastic depth decay rule
+        self.resblocks = nn.ModuleList(
+            [
+                ResidualAttentionBlock(self.width, self.heads, attn_mask, dpr[i])
+                for i in range(self.layers)
+            ]
+        )
+
+        self.ln_final = LayerNorm(self.width)
+
+        trunc_normal_(self.positional_embedding, std=.02)
+        # nn.init.normal_(self.token_embedding, std=.02)
+        trunc_normal_(self.token_embedding.weight, std=.02)
+        self.apply(self._init_weights)
+
+        # loading pre-trained weight from our CLIP models
+        if len(self.cfg.MODEL.LANGUAGE_BACKBONE.WEIGHT) > 0:
+            self.init_weights(pretrained=try_to_find(self.cfg.MODEL.LANGUAGE_BACKBONE.WEIGHT),
+                              pretrained_layers=['*'])
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Linear, nn.Conv2d)):
+            trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+            nn.init.constant_(m.bias, 0)
+
+    def resize_pos_embed_1d(self, posemb, shape_new):
+        # rescale the grid of position embeddings when loading from state_dict
+        ntok_old = posemb.shape[0]
+        if ntok_old > 1:
+            ntok_new = shape_new[0]
+            posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1).unsqueeze(dim=-1)
+            posemb_grid = F.interpolate(posemb_grid, size=[ntok_new, 1], mode='bilinear')
+            posemb_grid = posemb_grid.squeeze(dim=-1).permute(0, 2, 1).squeeze(dim=0)
+            posemb = posemb_grid
+        return posemb
+
+    def init_weights(self, pretrained="", pretrained_layers=[], verbose=False):
+        if os.path.isfile(pretrained):
+            pretrained_dict = torch.load(pretrained, map_location="cpu")
+            logger.info(f'=> loading pretrained clip text model {pretrained}')
+            model_dict = self.state_dict()
+
+            need_init_state_dict = {}
+            for k, v in pretrained_dict.items():
+                need_init = (
+                        k.split('.')[0] in pretrained_layers
+                        or pretrained_layers[0] is '*'
+                )
+                if need_init:
+                    if k.startswith('text.') and k[5:] in model_dict.keys():
+                        need_init_state_dict[k[5:]] = v
+
+            # notice the context length now changes from 77 to 256, so we need to resize the positional embedding
+            if "positional_embedding" in need_init_state_dict.keys():
+                old_pos_embed = need_init_state_dict["positional_embedding"].float()
+                new_pos_embed = self.resize_pos_embed_1d(old_pos_embed,
+                                                         (self.cfg.MODEL.CLIP.CONTEXT_LENGTH, old_pos_embed.shape[1]))
+                need_init_state_dict["positional_embedding"] = new_pos_embed
+            self.load_state_dict(need_init_state_dict, strict=True)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {
+            'positional_embedding',
+            'token_embedding',
+        }
+
+    def forward(self, text):
+        input = text["input_ids"]
+        mask = text["attention_mask"]
+        # get extended attention mask for nn.MultiHeadAttention
+        key_padding_mask = (1.0 - mask).to(torch.bool)
+
+        x = self.token_embedding(input)  # [batch_size, n_ctx, d_model]
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+
+        for resblock in self.resblocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(resblock, x, key_padding_mask)
+            else:
+                x = resblock(x, key_padding_mask)
+
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_final(x)
+
+        # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+
+        ret = {
+            "aggregate": x,
+            "embedded": x,
+            "masks": mask,
+            "hidden": x
+        }
+
+        return ret
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/language_backbone/hfpt_tokenizer.py b/maskrcnn_benchmark/modeling/language_backbone/hfpt_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..06dce89d75e3b91ee3405dd2e449b9d48dc861f2
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/hfpt_tokenizer.py
@@ -0,0 +1,99 @@
+from typing import Union, List
+
+from transformers import AutoTokenizer
+import torch
+
+
+class HFPTTokenizer(object):
+    def __init__(self, pt_name=None):
+
+        self.pt_name = pt_name
+        self.added_sep_token = 0
+        self.added_cls_token = 0
+        self.enable_add_tokens = False
+        self.gpt_special_case = ((not self.enable_add_tokens) and ('gpt' in self.pt_name))
+
+        if (pt_name is None):
+            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(pt_name)
+
+        # Adding tokens to GPT causing NaN training loss.
+        # Disable for now until further investigation.
+        if (self.enable_add_tokens):
+            if (self.tokenizer.sep_token is None):
+                self.tokenizer.add_special_tokens({'sep_token': '<SEP>'})
+                self.added_sep_token = 1
+
+            if (self.tokenizer.cls_token is None):
+                self.tokenizer.add_special_tokens({'cls_token': '<CLS>'})
+                self.added_cls_token = 1
+
+        if (self.gpt_special_case):
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+            self.tokenizer.sep_token = self.tokenizer.eos_token
+
+    def get_eot_token(self):
+        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)[0]
+
+    def get_sot_token(self):
+        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)[0]
+
+    def get_eot_token_list(self):
+        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)
+
+    def get_sot_token_list(self):
+        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)
+
+    def get_tokenizer_obj(self):
+        return self.tokenizer
+
+    # Language model needs to know if new tokens
+    # were added to the dictionary.
+    def check_added_tokens(self):
+        return self.added_sep_token + self.added_cls_token
+
+    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        padding = 'max_length'
+
+        seqstart = []
+        seqtok = []
+        seqend = []
+
+        max_length = context_length
+
+        if (self.added_cls_token > 0):
+            seqstart = self.get_sot_token_list()
+            max_length = max_length - 1
+
+        if (self.added_sep_token > 0):
+            seqend = self.get_eot_token_list()
+            max_length = max_length - 1
+
+        tokens = self.tokenizer(
+            texts, padding=padding,
+            truncation=True,
+            max_length=max_length
+        )['input_ids']
+
+        for i in range(len(tokens)):
+            tokens[i] = seqstart + tokens[i] + seqend
+
+        if (self.gpt_special_case):
+            for i in range(len(tokens)):
+                tokens[i][-1] = self.get_eot_token()
+
+        # print(str(tokens))
+
+        result = torch.Tensor(tokens).type(torch.LongTensor)
+
+        return result
+
+    def get_vocab_size(self):
+        return self.tokenizer.vocab_size
+
+    def __call__(self, texts: Union[str, List[str]], context_length: int = 77):
+        return self.tokenize(texts, context_length)
diff --git a/maskrcnn_benchmark/modeling/language_backbone/rnn_model.py b/maskrcnn_benchmark/modeling/language_backbone/rnn_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d60efcb08675b73bce211e42c0c180ffe8d267
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/rnn_model.py
@@ -0,0 +1,115 @@
+from copy import deepcopy
+import numpy as np
+import torch
+from torch import nn
+
+
+class RNNEnoder(nn.Module):
+    def __init__(self, cfg):
+        super(RNNEnoder, self).__init__()
+        self.cfg = cfg
+
+        self.rnn_type = cfg.MODEL.LANGUAGE_BACKBONE.RNN_TYPE
+        self.variable_length = cfg.MODEL.LANGUAGE_BACKBONE.VARIABLE_LENGTH
+        self.word_embedding_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_EMBEDDING_SIZE
+        self.word_vec_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_VEC_SIZE
+        self.hidden_size = cfg.MODEL.LANGUAGE_BACKBONE.HIDDEN_SIZE
+        self.bidirectional = cfg.MODEL.LANGUAGE_BACKBONE.BIDIRECTIONAL
+        self.input_dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.INPUT_DROPOUT_P
+        self.dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.DROPOUT_P
+        self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS
+        self.corpus_path = cfg.MODEL.LANGUAGE_BACKBONE.CORPUS_PATH
+        self.vocab_size = cfg.MODEL.LANGUAGE_BACKBONE.VOCAB_SIZE
+
+        # language encoder
+        self.embedding = nn.Embedding(self.vocab_size, self.word_embedding_size)
+        self.input_dropout = nn.Dropout(self.input_dropout_p)
+        self.mlp = nn.Sequential(nn.Linear(self.word_embedding_size, self.word_vec_size), nn.ReLU())
+        self.rnn = getattr(nn, self.rnn_type.upper())(self.word_vec_size,
+                                                      self.hidden_size,
+                                                      self.n_layers,
+                                                      batch_first=True,
+                                                      bidirectional=self.bidirectional,
+                                                      dropout=self.dropout_p)
+        self.num_dirs = 2 if self.bidirectional else 1
+
+    def forward(self, input, mask=None):
+        word_id = input
+        max_len = (word_id != 0).sum(1).max().item()
+        word_id = word_id[:, :max_len]  # mask zero
+        # embedding
+        output, hidden, embedded, final_output = self.RNNEncode(word_id)
+        return {
+            'hidden': hidden,
+            'output': output,
+            'embedded': embedded,
+            'final_output': final_output,
+        }
+
+    def encode(self, input_labels):
+        """
+                Inputs:
+                - input_labels: Variable long (batch, seq_len)
+                Outputs:
+                - output  : Variable float (batch, max_len, hidden_size * num_dirs)
+                - hidden  : Variable float (batch, num_layers * num_dirs * hidden_size)
+                - embedded: Variable float (batch, max_len, word_vec_size)
+                """
+        device = input_labels.device
+        if self.variable_length:
+            input_lengths_list, sorted_lengths_list, sort_idxs, recover_idxs = self.sort_inputs(input_labels)
+            input_labels = input_labels[sort_idxs]
+
+        embedded = self.embedding(input_labels)  # (n, seq_len, word_embedding_size)
+        embedded = self.input_dropout(embedded)  # (n, seq_len, word_embedding_size)
+        embedded = self.mlp(embedded)  # (n, seq_len, word_vec_size)
+
+        if self.variable_length:
+            if self.variable_length:
+                embedded = nn.utils.rnn.pack_padded_sequence(embedded, \
+                                                             sorted_lengths_list, \
+                                                             batch_first=True)
+        # forward rnn
+        self.rnn.flatten_parameters()
+        output, hidden = self.rnn(embedded)
+
+        # recover
+        if self.variable_length:
+            # recover embedded
+            embedded, _ = nn.utils.rnn.pad_packed_sequence(embedded,
+                                                           batch_first=True)  # (batch, max_len, word_vec_size)
+            embedded = embedded[recover_idxs]
+
+            # recover output
+            output, _ = nn.utils.rnn.pad_packed_sequence(output,
+                                                         batch_first=True)  # (batch, max_len, hidden_size * num_dir)
+            output = output[recover_idxs]
+
+            # recover hidden
+            if self.rnn_type == 'lstm':
+                hidden = hidden[0]  # hidden state
+            hidden = hidden[:, recover_idxs, :]  # (num_layers * num_dirs, batch, hidden_size)
+            hidden = hidden.transpose(0, 1).contiguous()  # (batch, num_layers * num_dirs, hidden_size)
+            hidden = hidden.view(hidden.size(0), -1)  # (batch, num_layers * num_dirs * hidden_size)
+
+        # final output
+        finnal_output = []
+        for ii in range(output.shape[0]):
+            finnal_output.append(output[ii, int(input_lengths_list[ii] - 1), :])
+        finnal_output = torch.stack(finnal_output, dim=0)  # (batch, number_dirs * hidden_size)
+
+        return output, hidden, embedded, finnal_output
+
+    def sort_inputs(self, input_labels):  # sort input labels by descending
+        device = input_labels.device
+        input_lengths = (input_labels != 0).sum(1)
+        input_lengths_list = input_lengths.data.cpu().numpy().tolist()
+        sorted_input_lengths_list = np.sort(input_lengths_list)[::-1].tolist()  # list of sorted input_lengths
+        sort_idxs = np.argsort(input_lengths_list)[::-1].tolist()
+        s2r = {s: r for r, s in enumerate(sort_idxs)}
+        recover_idxs = [s2r[s] for s in range(len(input_lengths_list))]
+        assert max(input_lengths_list) == input_labels.size(1)
+        # move to long tensor
+        sort_idxs = input_labels.data.new(sort_idxs).long().to(device)  # Variable long
+        recover_idxs = input_labels.data.new(recover_idxs).long().to(device)  # Variable long
+        return input_lengths_list, sorted_input_lengths_list, sort_idxs, recover_idxs
diff --git a/maskrcnn_benchmark/modeling/language_backbone/simple_tokenizer.py b/maskrcnn_benchmark/modeling/language_backbone/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8653b554bce885162452067b67359f07eb022174
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/simple_tokenizer.py
@@ -0,0 +1,173 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+from typing import Union, List
+
+import torch
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2 ** 8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2 ** 8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + '</w>'
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
+
+    def get_vocab_size(self):
+        return 49408
+
+    def get_eot_token(self):
+        return self.encoder["<|endoftext|>"]
+
+    def get_sot_token(self):
+        return self.encoder["<|startoftext|>"]
+
+    def check_added_tokens(self):
+        return 0
+
+    def get_tokenizer_obj(self):
+        return None
+
+    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        sot_token = self.encoder["<|startoftext|>"]
+        eot_token = self.encoder["<|endoftext|>"]
+        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                tokens = tokens[:context_length]
+                # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
+    def __call__(self, texts: Union[str, List[str]], context_length: int = 77):
+        return self.tokenize(texts, context_length)
diff --git a/maskrcnn_benchmark/modeling/language_backbone/test_clip_tokenizer.py b/maskrcnn_benchmark/modeling/language_backbone/test_clip_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5c63c73cfba5bd3a580e93852cd9c91fac00b35
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/test_clip_tokenizer.py
@@ -0,0 +1,8 @@
+from maskrcnn_benchmark.modeling.language_backbone import build_tokenizer
+
+if __name__ == '__main__':
+
+    tokenizer2 = build_tokenizer("clip")
+    tokenized2 = tokenizer2(
+        ["Detectest : fishid. jellyfishioasod. penguinasd. puffin.asd shark. starfish. round stingray"])
+    print(tokenized2)
diff --git a/maskrcnn_benchmark/modeling/language_backbone/word_utils.py b/maskrcnn_benchmark/modeling/language_backbone/word_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f453ba70bf8832f3f4124d82467b3803b09af1
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/language_backbone/word_utils.py
@@ -0,0 +1,100 @@
+"""
+Language-related data loading helper functions and class wrappers.
+"""
+
+import re
+import torch
+import codecs
+
+UNK_TOKEN = '<unk>'
+PAD_TOKEN = '<pad>'
+END_TOKEN = '<eos>'
+SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)')
+
+
+class Dictionary(object):
+    def __init__(self):
+        self.word2idx = {}
+        self.idx2word = []
+
+    def add_word(self, word):
+        if word not in self.word2idx:
+            self.idx2word.append(word)
+            self.word2idx[word] = len(self.idx2word) - 1
+        return self.word2idx[word]
+
+    def __len__(self):
+        return len(self.idx2word)
+
+    def __getitem__(self, a):
+        if isinstance(a, int):
+            return self.idx2word[a]
+        elif isinstance(a, list):
+            return [self.idx2word[x] for x in a]
+        elif isinstance(a, str):
+            return self.word2idx[a]
+        else:
+            raise TypeError("Query word/index argument must be int or str")
+
+    def __contains__(self, word):
+        return word in self.word2idx
+
+
+class Corpus(object):
+    def __init__(self):
+        self.dictionary = Dictionary()
+
+    def set_max_len(self, value):
+        self.max_len = value
+
+    def load_file(self, filename):
+        with codecs.open(filename, 'r', 'utf-8') as f:
+            for line in f:
+                line = line.strip()
+                self.add_to_corpus(line)
+        self.dictionary.add_word(UNK_TOKEN)
+        self.dictionary.add_word(PAD_TOKEN)
+
+    def add_to_corpus(self, line):
+        """Tokenizes a text line."""
+        # Add words to the dictionary
+        words = line.split()
+        # tokens = len(words)
+        for word in words:
+            word = word.lower()
+            self.dictionary.add_word(word)
+
+    def tokenize(self, line, max_len=20):
+        # Tokenize line contents
+        words = SENTENCE_SPLIT_REGEX.split(line.strip())
+        # words = [w.lower() for w in words if len(w) > 0]
+        words = [w.lower() for w in words if (len(w) > 0 and w != ' ')]  ## do not include space as a token
+
+        if words[-1] == '.':
+            words = words[:-1]
+
+        if max_len > 0:
+            if len(words) > max_len:
+                words = words[:max_len]
+            elif len(words) < max_len:
+                # words = [PAD_TOKEN] * (max_len - len(words)) + words
+                words = words + [END_TOKEN] + [PAD_TOKEN] * (max_len - len(words) - 1)
+
+        tokens = len(words)  ## for end token
+        ids = torch.LongTensor(tokens)
+        token = 0
+        for word in words:
+            if word not in self.dictionary:
+                word = UNK_TOKEN
+            # print(word, type(word), word.encode('ascii','ignore').decode('ascii'), type(word.encode('ascii','ignore').decode('ascii')))
+            if type(word) != type('a'):
+                print(word, type(word), word.encode('ascii', 'ignore').decode('ascii'),
+                      type(word.encode('ascii', 'ignore').decode('ascii')))
+                word = word.encode('ascii', 'ignore').decode('ascii')
+            ids[token] = self.dictionary[word]
+            token += 1
+        # ids[token] = self.dictionary[END_TOKEN]
+        return ids
+
+    def __len__(self):
+        return len(self.dictionary)
diff --git a/maskrcnn_benchmark/modeling/make_layers.py b/maskrcnn_benchmark/modeling/make_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..2216a952d04a295d0cf474d2f562903081fe0ea6
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/make_layers.py
@@ -0,0 +1,124 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+Miscellaneous utility functions
+"""
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from maskrcnn_benchmark.config import cfg
+from maskrcnn_benchmark.layers import Conv2d, DYReLU
+from maskrcnn_benchmark.modeling.poolers import Pooler
+
+
+def get_group_gn(dim, dim_per_gp, num_groups):
+    """get number of groups used by GroupNorm, based on number of channels."""
+    assert dim_per_gp == -1 or num_groups == -1, \
+        "GroupNorm: can only specify G or C/G."
+
+    if dim_per_gp > 0:
+        assert dim % dim_per_gp == 0, \
+            "dim: {}, dim_per_gp: {}".format(dim, dim_per_gp)
+        group_gn = dim // dim_per_gp
+    else:
+        assert dim % num_groups == 0, \
+            "dim: {}, num_groups: {}".format(dim, num_groups)
+        group_gn = num_groups
+
+    return group_gn
+
+
+def group_norm(out_channels, affine=True, divisor=1):
+    out_channels = out_channels // divisor
+    dim_per_gp = cfg.MODEL.GROUP_NORM.DIM_PER_GP // divisor
+    num_groups = cfg.MODEL.GROUP_NORM.NUM_GROUPS // divisor
+    eps = cfg.MODEL.GROUP_NORM.EPSILON # default: 1e-5
+    return torch.nn.GroupNorm(
+        get_group_gn(out_channels, dim_per_gp, num_groups), 
+        out_channels, 
+        eps, 
+        affine
+    )
+
+
+def make_conv3x3(
+    in_channels, 
+    out_channels, 
+    dilation=1, 
+    stride=1, 
+    use_gn=False,
+    use_relu=False,
+    kaiming_init=True
+):
+    conv = Conv2d(
+        in_channels, 
+        out_channels, 
+        kernel_size=3, 
+        stride=stride, 
+        padding=dilation, 
+        dilation=dilation, 
+        bias=False if use_gn else True
+    )
+    if kaiming_init:
+        nn.init.kaiming_normal_(
+            conv.weight, mode="fan_out", nonlinearity="relu"
+        )
+    else:
+        torch.nn.init.normal_(conv.weight, std=0.01)
+    if not use_gn:
+        nn.init.constant_(conv.bias, 0)
+    module = [conv,]
+    if use_gn:
+        module.append(group_norm(out_channels))
+    if use_relu:
+        module.append(nn.ReLU(inplace=True))
+    if len(module) > 1:
+        return nn.Sequential(*module)
+    return conv
+
+
+def make_fc(dim_in, hidden_dim, use_gn=False):
+    '''
+        Caffe2 implementation uses XavierFill, which in fact
+        corresponds to kaiming_uniform_ in PyTorch
+    '''
+    if use_gn:
+        fc = nn.Linear(dim_in, hidden_dim, bias=False)
+        nn.init.kaiming_uniform_(fc.weight, a=1)
+        return nn.Sequential(fc, group_norm(hidden_dim))
+    fc = nn.Linear(dim_in, hidden_dim)
+    nn.init.kaiming_uniform_(fc.weight, a=1)
+    nn.init.constant_(fc.bias, 0)
+    return fc
+
+
+def conv_with_kaiming_uniform(use_gn=False, use_relu=False, use_dyrelu=False):
+    def make_conv(
+        in_channels, out_channels, kernel_size, stride=1, dilation=1
+    ):
+        conv = Conv2d(
+            in_channels, 
+            out_channels, 
+            kernel_size=kernel_size, 
+            stride=stride, 
+            padding=dilation * (kernel_size - 1) // 2, 
+            dilation=dilation, 
+            bias=False if use_gn else True
+        )
+        # Caffe2 implementation uses XavierFill, which in fact
+        # corresponds to kaiming_uniform_ in PyTorch
+        nn.init.kaiming_uniform_(conv.weight, a=1)
+        if not use_gn:
+            nn.init.constant_(conv.bias, 0)
+        module = [conv,]
+        if use_gn:
+            module.append(group_norm(out_channels))
+        if use_relu:
+            module.append(nn.ReLU(inplace=True))
+        if use_dyrelu:
+            module.append(DYReLU(out_channels, out_channels, use_spatial=True))
+        if len(module) > 1:
+            return nn.Sequential(*module)
+        return conv
+
+    return make_conv
diff --git a/maskrcnn_benchmark/modeling/matcher.py b/maskrcnn_benchmark/modeling/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..d080b0546b8e3e581ced4fbac89cca4dfde78b1a
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/matcher.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+
+class Matcher(object):
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be assigned to zero or more predicted elements.
+
+    Matching is based on the MxN match_quality_matrix, that characterizes how well
+    each (ground-truth, predicted)-pair match. For example, if the elements are
+    boxes, the matrix may contain box IoU overlap values.
+
+    The matcher returns a tensor of size N containing the index of the ground-truth
+    element m that matches to prediction n. If there is no match, a negative value
+    is returned.
+    """
+
+    BELOW_LOW_THRESHOLD = -1
+    BETWEEN_THRESHOLDS = -2
+
+    def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
+        """
+        Args:
+            high_threshold (float): quality values greater than or equal to
+                this value are candidate matches.
+            low_threshold (float): a lower quality threshold used to stratify
+                matches into three levels:
+                1) matches >= high_threshold
+                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions that have only low-quality match candidates. See
+                set_low_quality_matches_ for more details.
+        """
+        assert low_threshold <= high_threshold
+        self.high_threshold = high_threshold
+        self.low_threshold = low_threshold
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix):
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+            pairwise quality between M ground-truth elements and N predicted elements.
+
+        Returns:
+            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+            [0, M - 1] or a negative value indicating that prediction i could not
+            be matched.
+        """
+        if match_quality_matrix.numel() == 0:
+            # empty targets or proposals not supported during training
+            if match_quality_matrix.shape[0] == 0:
+                # raise ValueError(
+                #     "No ground-truth boxes available for one of the images "
+                #     "during training")
+                length = match_quality_matrix.size(1)
+                device = match_quality_matrix.device
+                return torch.ones(length, dtype=torch.int64, device=device) * -1
+            else:
+                raise ValueError(
+                    "No proposal boxes available for one of the images "
+                    "during training")
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+        if self.allow_low_quality_matches:
+            all_matches = matches.clone()
+
+        # Assign candidate matches with low quality to negative (unassigned) values
+        below_low_threshold = matched_vals < self.low_threshold
+        between_thresholds = (matched_vals >= self.low_threshold) & (
+            matched_vals < self.high_threshold
+        )
+        matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD
+        matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS
+
+        if self.allow_low_quality_matches:
+            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+        return matches
+
+    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth with which it has the highest
+        quality value.
+        """
+        # For each gt, find the prediction with which it has highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find highest quality match available, even if it is low, including ties
+        gt_pred_pairs_of_highest_quality = torch.nonzero(
+            match_quality_matrix == highest_quality_foreach_gt[:, None]
+        )
+        # Example gt_pred_pairs_of_highest_quality:
+        #   tensor([[    0, 39796],
+        #           [    1, 32055],
+        #           [    1, 32070],
+        #           [    2, 39190],
+        #           [    2, 40255],
+        #           [    3, 40390],
+        #           [    3, 41455],
+        #           [    4, 45470],
+        #           [    5, 45325],
+        #           [    5, 46390]])
+        # Each row is a (gt index, prediction index)
+        # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
+        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
diff --git a/maskrcnn_benchmark/modeling/poolers.py b/maskrcnn_benchmark/modeling/poolers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad136731b58a97bbf3d8266ee301d1c930c8fa6e
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/poolers.py
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from maskrcnn_benchmark.layers import ROIAlign, ROIAlignV2
+
+from .utils import cat
+
+
+class LevelMapper(object):
+    """Determine which FPN level each RoI in a set of RoIs should map to based
+    on the heuristic in the FPN paper.
+    """
+
+    def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
+        """
+        Arguments:
+            k_min (int)
+            k_max (int)
+            canonical_scale (int)
+            canonical_level (int)
+            eps (float)
+        """
+        self.k_min = k_min
+        self.k_max = k_max
+        self.s0 = canonical_scale
+        self.lvl0 = canonical_level
+        self.eps = eps
+
+    def __call__(self, boxlists):
+        """
+        Arguments:
+            boxlists (list[BoxList])
+        """
+        # Compute level ids
+        s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists]))
+
+        # Eqn.(1) in FPN paper
+        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
+        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
+        return target_lvls.to(torch.int64) - self.k_min
+
+
+class Pooler(nn.Module):
+    """
+    Pooler for Detection with or without FPN.
+    It currently hard-code ROIAlign in the implementation,
+    but that can be made more generic later on.
+    Also, the requirement of passing the scales is not strictly necessary, as they
+    can be inferred from the size of the feature map / size of original image,
+    which is available thanks to the BoxList.
+    """
+
+    def __init__(self, output_size, scales, sampling_ratio, use_v2=False):
+        """
+        Arguments:
+            output_size (list[tuple[int]] or list[int]): output size for the pooled region
+            scales (list[float]): scales for each Pooler
+            sampling_ratio (int): sampling ratio for ROIAlign
+        """
+        super(Pooler, self).__init__()
+        poolers = []
+        for scale in scales:
+            poolers.append(
+                    ROIAlignV2(
+                        output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
+                    )
+                    if use_v2 else
+                    ROIAlign(
+                        output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
+                    )
+            )
+        self.poolers = nn.ModuleList(poolers)
+        self.output_size = output_size
+        # get the levels in the feature map by leveraging the fact that the network always
+        # downsamples by a factor of 2 at each level.
+        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
+        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
+        self.map_levels = LevelMapper(lvl_min, lvl_max)
+
+    def convert_to_roi_format(self, boxes):
+        concat_boxes = cat([b.bbox for b in boxes], dim=0)
+        device, dtype = concat_boxes.device, concat_boxes.dtype
+        ids = cat(
+            [
+                torch.full((len(b), 1), i, dtype=dtype, device=device)
+                for i, b in enumerate(boxes)
+            ],
+            dim=0,
+        )
+        rois = torch.cat([ids, concat_boxes], dim=1)
+        return rois
+
+    def forward(self, x, boxes):
+        """
+        Arguments:
+            x (list[Tensor]): feature maps for each level
+            boxes (list[BoxList]): boxes to be used to perform the pooling operation.
+        Returns:
+            result (Tensor)
+        """
+        num_levels = len(self.poolers)
+        rois = self.convert_to_roi_format(boxes)
+        if num_levels == 1:
+            return self.poolers[0](x[0], rois)
+
+        levels = self.map_levels(boxes)
+
+        num_rois = len(rois)
+        num_channels = x[0].shape[1]
+        output_size = self.output_size[0]
+
+        dtype, device = x[0].dtype, x[0].device
+        result = torch.zeros(
+            (num_rois, num_channels, output_size, output_size),
+            dtype=dtype,
+            device=device,
+        )
+        for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
+            idx_in_level = torch.nonzero(levels == level).squeeze(1)
+            rois_per_level = rois[idx_in_level]
+            result[idx_in_level] = pooler(per_level_feature, rois_per_level)
+
+        return result
diff --git a/maskrcnn_benchmark/modeling/registry.py b/maskrcnn_benchmark/modeling/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d828cdbb550a242a2b2a944fc1c7efccbe9da90
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/registry.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from maskrcnn_benchmark.utils.registry import Registry
+
+BACKBONES = Registry()
+
+LANGUAGE_BACKBONES = Registry()
+
+ROI_BOX_FEATURE_EXTRACTORS = Registry()
+RPN_HEADS = Registry()
diff --git a/maskrcnn_benchmark/modeling/roi_heads/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c6b92b4e6adc2c9b592f4cdee794b36b57a4548
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/__init__.py
@@ -0,0 +1,84 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+from .box_head.box_head import build_roi_box_head
+from .mask_head.mask_head import build_roi_mask_head
+from .keypoint_head.keypoint_head import build_roi_keypoint_head
+
+
+class CombinedROIHeads(torch.nn.ModuleDict):
+    """
+    Combines a set of individual heads (for box prediction or masks) into a single
+    head.
+    """
+
+    def __init__(self, cfg, heads):
+        super(CombinedROIHeads, self).__init__(heads)
+        self.cfg = cfg.clone()
+        if cfg.MODEL.MASK_ON and cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
+            self.mask.feature_extractor = self.box.feature_extractor
+        if cfg.MODEL.KEYPOINT_ON and cfg.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
+            self.keypoint.feature_extractor = self.box.feature_extractor
+
+    def forward(self, features, proposals, targets=None, language_dict_features=None, positive_map_label_to_token=None):
+        losses = {}
+        detections = proposals
+        if self.cfg.MODEL.BOX_ON:
+            # TODO rename x to roi_box_features, if it doesn't increase memory consumption
+            x, detections, loss_box = self.box(features, proposals, targets)
+            losses.update(loss_box)
+
+        if self.cfg.MODEL.MASK_ON:
+            mask_features = features
+            # optimization: during training, if we share the feature extractor between
+            # the box and the mask heads, then we can reuse the features already computed
+            if (
+                    self.training
+                    and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR
+            ):
+                mask_features = x
+            # During training, self.box() will return the unaltered proposals as "detections"
+            # this makes the API consistent during training and testing
+            x, detections, loss_mask = self.mask(
+                mask_features, detections, targets,
+                language_dict_features=language_dict_features,
+                positive_map_label_to_token=positive_map_label_to_token)
+            losses.update(loss_mask)
+
+        if self.cfg.MODEL.KEYPOINT_ON:
+            keypoint_features = features
+            # optimization: during training, if we share the feature extractor between
+            # the box and the mask heads, then we can reuse the features already computed
+            if (
+                    self.training
+                    and self.cfg.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR
+            ):
+                keypoint_features = x
+            # During training, self.box() will return the unaltered proposals as "detections"
+            # this makes the API consistent during training and testing
+            x, detections, loss_keypoint = self.keypoint(keypoint_features, detections, targets)
+            losses.update(loss_keypoint)
+        return x, detections, losses
+
+
+def build_roi_heads(cfg):
+    # individually create the heads, that will be combined together
+    # afterwards
+    # if cfg.MODEL.RPN_ONLY:
+    #     return None
+
+    roi_heads = []
+    if cfg.MODEL.BOX_ON and not cfg.MODEL.RPN_ONLY:
+        roi_heads.append(("box", build_roi_box_head(cfg)))
+    if cfg.MODEL.MASK_ON:
+        roi_heads.append(("mask", build_roi_mask_head(cfg)))
+    if cfg.MODEL.KEYPOINT_ON:
+        roi_heads.append(("keypoint", build_roi_keypoint_head(cfg)))
+
+    # combine individual heads in a single module
+    if roi_heads:
+        roi_heads = CombinedROIHeads(cfg, roi_heads)
+    else:
+        roi_heads = None
+
+    return roi_heads
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d509ee6bb0c75d51960c192f782c4b2a8178a96e
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py
@@ -0,0 +1,75 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch import nn
+
+from .roi_box_feature_extractors import make_roi_box_feature_extractor
+from .roi_box_predictors import make_roi_box_predictor
+from .inference import make_roi_box_post_processor
+from .loss import make_roi_box_loss_evaluator
+from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
+
+class ROIBoxHead(torch.nn.Module):
+    """
+    Generic Box Head class.
+    """
+
+    def __init__(self, cfg):
+        super(ROIBoxHead, self).__init__()
+        self.feature_extractor = make_roi_box_feature_extractor(cfg)
+        self.predictor = make_roi_box_predictor(cfg)
+        self.post_processor = make_roi_box_post_processor(cfg)
+        self.loss_evaluator = make_roi_box_loss_evaluator(cfg)
+        self.onnx = cfg.MODEL.ONNX
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(self, features, proposals, targets=None):
+        """
+        Arguments:
+            features (list[Tensor]): feature-maps from possibly several levels
+            proposals (list[BoxList]): proposal boxes
+            targets (list[BoxList], optional): the ground-truth targets.
+
+        Returns:
+            x (Tensor): the result of the feature extractor
+            proposals (list[BoxList]): during training, the subsampled proposals
+                are returned. During testing, the predicted boxlists are returned
+            losses (dict[Tensor]): During training, returns the losses for the
+                head. During testing, returns an empty dict.
+        """
+
+        if self.training:
+            # Faster R-CNN subsamples during training the proposals with a fixed
+            # positive / negative ratio
+            with torch.no_grad():
+                proposals = self.loss_evaluator.subsample(proposals, targets)
+
+        # extract features that will be fed to the final classifier. The
+        # feature_extractor generally corresponds to the pooler + heads
+        x = self.feature_extractor(features, proposals)
+        # final classifier that converts the features into predictions
+        class_logits, box_regression = self.predictor(x)
+
+        if self.onnx:
+            return x, (class_logits, box_regression, [box.bbox for box in proposals]), {}
+
+        if not self.training:
+            result = self.post_processor((class_logits, box_regression), proposals)
+            return x, result, {}
+
+        loss_classifier, loss_box_reg = self.loss_evaluator(
+            [class_logits], [box_regression]
+        )
+        return (
+            x,
+            proposals,
+            dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg),
+        )
+
+
+def build_roi_box_head(cfg):
+    """
+    Constructs a new box head.
+    By default, uses ROIBoxHead, but if it turns out not to be enough, just register a new class
+    and make it a parameter in the config
+    """
+    return ROIBoxHead(cfg)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f64bb3060e22388caf57c2496c7eb6f7a4cb7f4
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
@@ -0,0 +1,177 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.modeling.box_coder import BoxCoder
+from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
+
+class PostProcessor(nn.Module):
+    """
+    From a set of classification scores, box regression and proposals,
+    computes the post-processed boxes, and applies NMS to obtain the
+    final results
+    """
+
+    def __init__(
+        self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None
+    ):
+        """
+        Arguments:
+            score_thresh (float)
+            nms (float)
+            detections_per_img (int)
+            box_coder (BoxCoder)
+        """
+        super(PostProcessor, self).__init__()
+        self.score_thresh = score_thresh
+        self.nms = nms
+        self.detections_per_img = detections_per_img
+        if box_coder is None:
+            box_coder = BoxCoder(weights=(10., 10., 5., 5.))
+        self.box_coder = box_coder
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(self, x, boxes):
+        """
+        Arguments:
+            x (tuple[tensor, tensor]): x contains the class logits
+                and the box_regression from the model.
+            boxes (list[BoxList]): bounding boxes that are used as
+                reference, one for ech image
+
+        Returns:
+            results (list[BoxList]): one BoxList for each image, containing
+                the extra fields labels and scores
+        """
+        class_logits, box_regression = x
+        class_prob = F.softmax(class_logits, -1)
+
+        # TODO think about a representation of batch of boxes
+        image_shapes = [box.size for box in boxes]
+        boxes_per_image = [len(box) for box in boxes]
+        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)
+
+        extra_fields = [{} for box in boxes]
+        if boxes[0].has_field("cbox"):
+            concat_cboxes = torch.cat([a.get_field('cbox').bbox for a in boxes], dim=0)
+            concat_cscores = torch.cat([a.get_field('cbox').get_field('scores') for a in boxes], dim=0)
+            for cbox, cscore, extra_field in zip(concat_cboxes.split(boxes_per_image, dim=0),
+                                                 concat_cscores.split(boxes_per_image, dim=0),
+                                                 extra_fields):
+                extra_field["cbox"] = cbox
+                extra_field["cscore"] = cscore
+
+        proposals = self.box_coder.decode(
+            box_regression.view(sum(boxes_per_image), -1), concat_boxes
+        )
+
+        num_classes = class_prob.shape[1]
+
+        proposals = proposals.split(boxes_per_image, dim=0)
+        class_prob = class_prob.split(boxes_per_image, dim=0)
+
+        results = []
+        for prob, boxes_per_img, image_shape, extra_field in zip(
+            class_prob, proposals, image_shapes, extra_fields
+        ):
+            boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape, extra_field)
+            boxlist = boxlist.clip_to_image(remove_empty=False)
+            boxlist = self.filter_results(boxlist, num_classes)
+            results.append(boxlist)
+        return results
+
+    def prepare_boxlist(self, boxes, scores, image_shape, extra_field={}):
+        """
+        Returns BoxList from `boxes` and adds probability scores information
+        as an extra field
+        `boxes` has shape (#detections, 4 * #classes), where each row represents
+        a list of predicted bounding boxes for each of the object classes in the
+        dataset (including the background class). The detections in each row
+        originate from the same object proposal.
+        `scores` has shape (#detection, #classes), where each row represents a list
+        of object detection confidence scores for each of the object classes in the
+        dataset (including the background class). `scores[i, j]`` corresponds to the
+        box at `boxes[i, j * 4:(j + 1) * 4]`.
+        """
+        boxes = boxes.reshape(-1, 4)
+        scores = scores.reshape(-1)
+        boxlist = BoxList(boxes, image_shape, mode="xyxy")
+        boxlist.add_field("scores", scores)
+        for key, val in extra_field.items():
+            boxlist.add_field(key, val)
+        return boxlist
+
+    def filter_results(self, boxlist, num_classes):
+        """Returns bounding-box detection results by thresholding on scores and
+        applying non-maximum suppression (NMS).
+        """
+        # unwrap the boxlist to avoid additional overhead.
+        # if we had multi-class NMS, we could perform this directly on the boxlist
+        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
+        scores = boxlist.get_field("scores").reshape(-1, num_classes)
+        if boxlist.has_field('cbox'):
+            cboxes = boxlist.get_field("cbox").reshape(-1, 4)
+            cscores = boxlist.get_field("cscore")
+        else:
+            cboxes = None
+
+        device = scores.device
+        result = []
+        # Apply threshold on detection probabilities and apply NMS
+        # Skip j = 0, because it's the background class
+        inds_all = scores > self.score_thresh
+        for j in range(1, num_classes):
+            inds = inds_all[:, j].nonzero().squeeze(1)
+            scores_j = scores[inds, j]
+            boxes_j = boxes[inds, j * 4 : (j + 1) * 4]
+            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
+            boxlist_for_class.add_field("scores", scores_j)
+            if cboxes is not None:
+                cboxes_j = cboxes[inds, :]
+                cscores_j = cscores[inds]
+                cbox_boxlist = BoxList(cboxes_j, boxlist.size, mode="xyxy")
+                cbox_boxlist.add_field("scores", cscores_j)
+                boxlist_for_class.add_field("cbox", cbox_boxlist)
+
+            boxlist_for_class = boxlist_nms(
+                boxlist_for_class, self.nms, score_field="scores"
+            )
+            num_labels = len(boxlist_for_class)
+            boxlist_for_class.add_field(
+                "labels", torch.full((num_labels,), j, dtype=torch.int64, device=device)
+            )
+            result.append(boxlist_for_class)
+
+        result = cat_boxlist(result)
+        number_of_detections = len(result)
+
+        # Limit to max_per_image detections **over all classes**
+        if number_of_detections > self.detections_per_img > 0:
+            cls_scores = result.get_field("scores")
+            image_thresh, _ = torch.kthvalue(
+                cls_scores.cpu(), number_of_detections - self.detections_per_img + 1
+            )
+            keep = cls_scores >= image_thresh.item()
+            keep = torch.nonzero(keep).squeeze(1)
+            result = result[keep]
+        return result
+
+
+def make_roi_box_post_processor(cfg):
+    use_fpn = cfg.MODEL.ROI_HEADS.USE_FPN
+
+    bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
+    box_coder = BoxCoder(weights=bbox_reg_weights)
+
+    score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH
+    nms_thresh = cfg.MODEL.ROI_HEADS.NMS
+    detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
+
+    postprocessor = PostProcessor(
+        score_thresh, nms_thresh, detections_per_img, box_coder
+    )
+    return postprocessor
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7592981fdf236c086d70b455967cf12ad0d275e
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
@@ -0,0 +1,187 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch.nn import functional as F
+
+from maskrcnn_benchmark.layers import smooth_l1_loss
+from maskrcnn_benchmark.modeling.box_coder import BoxCoder
+from maskrcnn_benchmark.modeling.matcher import Matcher
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import (
+    BalancedPositiveNegativeSampler
+)
+from maskrcnn_benchmark.modeling.utils import cat
+from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
+
+class FastRCNNLossComputation(object):
+    """
+    Computes the loss for Faster R-CNN.
+    Also supports FPN
+    """
+
+    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder):
+        """
+        Arguments:
+            proposal_matcher (Matcher)
+            fg_bg_sampler (BalancedPositiveNegativeSampler)
+            box_coder (BoxCoder)
+        """
+        self.proposal_matcher = proposal_matcher
+        self.fg_bg_sampler = fg_bg_sampler
+        self.box_coder = box_coder
+
+    def match_targets_to_proposals(self, proposal, target):
+        match_quality_matrix = boxlist_iou(target, proposal)
+        matched_idxs = self.proposal_matcher(match_quality_matrix)
+        # Fast RCNN only need "labels" field for selecting the targets
+        target = target.copy_with_fields("labels")
+        # get the targets corresponding GT for each proposal
+        # NB: need to clamp the indices because we can have a single
+        # GT in the image, and matched_idxs can be -2, which goes
+        # out of bounds
+
+        if len(target):
+            matched_targets = target[matched_idxs.clamp(min=0)]
+        else:
+            device = target.get_field('labels').device
+            dtype = target.get_field('labels').dtype
+            labels = torch.zeros_like(matched_idxs, dtype=dtype, device=device)
+            matched_targets = target
+            matched_targets.add_field('labels', labels)
+
+        matched_targets.add_field("matched_idxs", matched_idxs)
+        return matched_targets
+
+    def prepare_targets(self, proposals, targets):
+        labels = []
+        regression_targets = []
+        for proposals_per_image, targets_per_image in zip(proposals, targets):
+            matched_targets = self.match_targets_to_proposals(
+                proposals_per_image, targets_per_image
+            )
+            matched_idxs = matched_targets.get_field("matched_idxs")
+
+            labels_per_image = matched_targets.get_field("labels")
+            labels_per_image = labels_per_image.to(dtype=torch.int64)
+
+            # Label background (below the low threshold)
+            bg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
+            labels_per_image[bg_inds] = 0
+
+            # Label ignore proposals (between low and high thresholds)
+            ignore_inds = matched_idxs == Matcher.BETWEEN_THRESHOLDS
+            labels_per_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            # compute regression targets
+            if not matched_targets.bbox.shape[0]:
+                zeros = torch.zeros_like(labels_per_image, dtype=torch.float32)
+                regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1)
+            else:
+                regression_targets_per_image = self.box_coder.encode(matched_targets.bbox, proposals_per_image.bbox)
+
+            labels.append(labels_per_image)
+            regression_targets.append(regression_targets_per_image)
+
+        return labels, regression_targets
+
+    def subsample(self, proposals, targets):
+        """
+        This method performs the positive/negative sampling, and return
+        the sampled proposals.
+        Note: this function keeps a state.
+
+        Arguments:
+            proposals (list[BoxList])
+            targets (list[BoxList])
+        """
+
+        labels, regression_targets = self.prepare_targets(proposals, targets)
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+
+        proposals = list(proposals)
+        # add corresponding label and regression_targets information to the bounding boxes
+        for labels_per_image, regression_targets_per_image, proposals_per_image in zip(
+            labels, regression_targets, proposals
+        ):
+            proposals_per_image.add_field("labels", labels_per_image)
+            proposals_per_image.add_field(
+                "regression_targets", regression_targets_per_image
+            )
+
+        # distributed sampled proposals, that were obtained on all feature maps
+        # concatenated via the fg_bg_sampler, into individual feature map levels
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
+            zip(sampled_pos_inds, sampled_neg_inds)
+        ):
+            img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
+            proposals_per_image = proposals[img_idx][img_sampled_inds]
+            proposals[img_idx] = proposals_per_image
+
+        self._proposals = proposals
+        return proposals
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def __call__(self, class_logits, box_regression):
+        """
+        Computes the loss for Faster R-CNN.
+        This requires that the subsample method has been called beforehand.
+
+        Arguments:
+            class_logits (list[Tensor])
+            box_regression (list[Tensor])
+
+        Returns:
+            classification_loss (Tensor)
+            box_loss (Tensor)
+        """
+
+        class_logits = cat(class_logits, dim=0)
+        box_regression = cat(box_regression, dim=0)
+        device = class_logits.device
+
+        if not hasattr(self, "_proposals"):
+            raise RuntimeError("subsample needs to be called before")
+
+        proposals = self._proposals
+
+        labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0)
+        regression_targets = cat(
+            [proposal.get_field("regression_targets") for proposal in proposals], dim=0
+        )
+
+        classification_loss = F.cross_entropy(class_logits, labels)
+
+        # get indices that correspond to the regression targets for
+        # the corresponding ground truth labels, to be used with
+        # advanced indexing
+        sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
+        labels_pos = labels[sampled_pos_inds_subset]
+        map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device)
+
+        box_loss = smooth_l1_loss(
+            box_regression[sampled_pos_inds_subset[:, None], map_inds],
+            regression_targets[sampled_pos_inds_subset],
+            size_average=False,
+            beta=1,
+        )
+        box_loss = box_loss / labels.numel()
+
+        return classification_loss, box_loss
+
+
+def make_roi_box_loss_evaluator(cfg):
+    matcher = Matcher(
+        cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD,
+        cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD,
+        allow_low_quality_matches=False,
+    )
+
+    bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
+    box_coder = BoxCoder(weights=bbox_reg_weights)
+
+    fg_bg_sampler = BalancedPositiveNegativeSampler(
+        cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
+    )
+
+    loss_evaluator = FastRCNNLossComputation(matcher, fg_bg_sampler, box_coder)
+
+    return loss_evaluator
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
new file mode 100644
index 0000000000000000000000000000000000000000..8614c78d8f7a85874175f82eff042eb793e44c4b
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
@@ -0,0 +1,201 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from maskrcnn_benchmark.modeling import registry
+from maskrcnn_benchmark.modeling.backbone import resnet
+from maskrcnn_benchmark.modeling.poolers import Pooler
+from maskrcnn_benchmark.modeling.make_layers import group_norm
+from maskrcnn_benchmark.modeling.make_layers import make_fc
+
+
+
+@registry.ROI_BOX_FEATURE_EXTRACTORS.register("LightheadFeatureExtractor")
+class LightheadFeatureExtractor(nn.Module):
+    def __init__(self, cfg):
+        super(LightheadFeatureExtractor, self).__init__()
+
+        resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+        input_size = 10 * resolution ** 2
+        representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
+        use_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN
+
+        C_in, C_mid, C_out = cfg.MODEL.BACKBONE.OUT_CHANNELS, 256, input_size
+        self.separable_conv_11 = nn.Conv2d(C_in, C_mid, (15, 1), 1, (7, 0))
+        self.separable_conv_12 = nn.Conv2d(C_mid, C_out, (1, 15), 1, (0, 7))
+        self.separable_conv_21 = nn.Conv2d(C_in, C_mid, (15, 1), 1, (7, 0))
+        self.separable_conv_22 = nn.Conv2d(C_mid, C_out, (1, 15), 1, (0, 7))
+
+        for module in [self.separable_conv_11, self.separable_conv_12, self.separable_conv_21, self.separable_conv_22]:
+            # Caffe2 implementation uses XavierFill, which in fact
+            # corresponds to kaiming_uniform_ in PyTorch
+            nn.init.kaiming_uniform_(module.weight, a=1)
+
+        self.pooler = pooler
+        self.fc6 = make_fc(input_size * resolution ** 2, representation_size, use_gn) #<TODO> wait official repo to support psroi
+
+
+    def forward(self, x, proposals):
+        light = []
+        for feat in x:
+            sc11 = self.separable_conv_11(feat)
+            sc12 = self.separable_conv_12(sc11)
+            sc21 = self.separable_conv_21(feat)
+            sc22 = self.separable_conv_22(sc21)
+            out = sc12+sc22
+            light.append(out)
+
+        x = self.pooler(light, proposals)
+        x = x.view(x.size(0), -1)
+        x = F.relu(self.fc6(x))
+
+        return x
+
+
+
+
+@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
+class ResNet50Conv5ROIFeatureExtractor(nn.Module):
+    def __init__(self, config):
+        super(ResNet50Conv5ROIFeatureExtractor, self).__init__()
+
+        resolution = config.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
+        scales = config.MODEL.ROI_BOX_HEAD.POOLER_SCALES
+        sampling_ratio = config.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+
+        stage = resnet.StageSpec(index=4, block_count=3, return_features=False)
+        head = resnet.ResNetHead(
+            block_module=config.MODEL.RESNETS.TRANS_FUNC,
+            stages=(stage,),
+            num_groups=config.MODEL.RESNETS.NUM_GROUPS,
+            width_per_group=config.MODEL.RESNETS.WIDTH_PER_GROUP,
+            stride_in_1x1=config.MODEL.RESNETS.STRIDE_IN_1X1,
+            stride_init=None,
+            res2_out_channels=config.MODEL.RESNETS.RES2_OUT_CHANNELS,
+            dilation=config.MODEL.RESNETS.RES5_DILATION
+        )
+
+        self.pooler = pooler
+        self.head = head
+
+    def forward(self, x, proposals):
+        x = self.pooler(x, proposals)
+        x = self.head(x)
+        return x
+
+
+@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor")
+class FPN2MLPFeatureExtractor(nn.Module):
+    """
+    Heads for FPN for classification
+    """
+
+    def __init__(self, cfg):
+        super(FPN2MLPFeatureExtractor, self).__init__()
+
+        resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+        input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS * resolution ** 2
+        representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
+        use_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN
+        self.pooler = pooler
+        self.fc6 = make_fc(input_size, representation_size, use_gn)
+        self.fc7 = make_fc(representation_size, representation_size, use_gn)
+
+    def forward(self, x, proposals):
+        x = self.pooler(x, proposals)
+        x = x.view(x.size(0), -1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPNXconv1fcFeatureExtractor")
+class FPNXconv1fcFeatureExtractor(nn.Module):
+    """
+    Heads for FPN for classification
+    """
+
+    def __init__(self, cfg):
+        super(FPNXconv1fcFeatureExtractor, self).__init__()
+
+        resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+        self.pooler = pooler
+        
+        use_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN
+        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        conv_head_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_HEAD_DIM
+        num_stacked_convs = cfg.MODEL.ROI_BOX_HEAD.NUM_STACKED_CONVS
+        dilation = cfg.MODEL.ROI_BOX_HEAD.DILATION
+
+        xconvs = []
+        for ix in range(num_stacked_convs):
+            xconvs.append(
+                nn.Conv2d(
+                    in_channels,
+                    conv_head_dim,
+                    kernel_size=3,
+                    stride=1,
+                    padding=dilation,
+                    dilation=dilation,
+                    bias=False if use_gn else True
+                )
+            )
+            in_channels = conv_head_dim
+            if use_gn:
+                xconvs.append(group_norm(in_channels))
+            xconvs.append(nn.ReLU(inplace=True))
+
+        self.add_module("xconvs", nn.Sequential(*xconvs))
+        for modules in [self.xconvs,]:
+            for l in modules.modules():
+                if isinstance(l, nn.Conv2d):
+                    torch.nn.init.normal_(l.weight, std=0.01)
+                    if not use_gn:
+                        torch.nn.init.constant_(l.bias, 0)
+
+        input_size = conv_head_dim * resolution ** 2
+        representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
+        self.fc6 = make_fc(input_size, representation_size, use_gn=False)
+
+    def forward(self, x, proposals):
+        x = self.pooler(x, proposals)
+        x = self.xconvs(x)
+        x = x.view(x.size(0), -1)
+        x = F.relu(self.fc6(x))
+        return x
+
+
+def make_roi_box_feature_extractor(cfg):
+    func = registry.ROI_BOX_FEATURE_EXTRACTORS[
+        cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR
+    ]
+    return func(cfg)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac03cfaece2e47900fc04b58e173f6dea6423caa
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
@@ -0,0 +1,62 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from torch import nn
+
+
+class FastRCNNPredictor(nn.Module):
+    def __init__(self, config, pretrained=None):
+        super(FastRCNNPredictor, self).__init__()
+
+        stage_index = 4
+        stage2_relative_factor = 2 ** (stage_index - 1)
+        res2_out_channels = config.MODEL.RESNETS.RES2_OUT_CHANNELS
+        num_inputs = res2_out_channels * stage2_relative_factor
+
+        num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
+        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7)
+        self.cls_score = nn.Linear(num_inputs, num_classes)
+        self.bbox_pred = nn.Linear(num_inputs, num_classes * 4)
+
+        nn.init.normal_(self.cls_score.weight, mean=0, std=0.01)
+        nn.init.constant_(self.cls_score.bias, 0)
+
+        nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.001)
+        nn.init.constant_(self.bbox_pred.bias, 0)
+
+    def forward(self, x):
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        cls_logit = self.cls_score(x)
+        bbox_pred = self.bbox_pred(x)
+        return cls_logit, bbox_pred
+
+
+class FPNPredictor(nn.Module):
+    def __init__(self, cfg):
+        super(FPNPredictor, self).__init__()
+        num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
+        representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
+
+        self.cls_score = nn.Linear(representation_size, num_classes)
+        self.bbox_pred = nn.Linear(representation_size, num_classes * 4)
+
+        nn.init.normal_(self.cls_score.weight, std=0.01)
+        nn.init.normal_(self.bbox_pred.weight, std=0.001)
+        for l in [self.cls_score, self.bbox_pred]:
+            nn.init.constant_(l.bias, 0)
+
+    def forward(self, x):
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
+
+
+_ROI_BOX_PREDICTOR = {
+    "FastRCNNPredictor": FastRCNNPredictor,
+    "FPNPredictor": FPNPredictor,
+}
+
+
+def make_roi_box_predictor(cfg):
+    func = _ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR]
+    return func(cfg)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8ed960ff37cf1c68ac8831fdb87b82c91203ec2
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/inference.py
@@ -0,0 +1,121 @@
+import cv2
+import numpy as np
+import torch
+from torch import nn
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
+
+
+class KeypointPostProcessor(nn.Module):
+    def __init__(self, keypointer=None):
+        super(KeypointPostProcessor, self).__init__()
+        self.keypointer = keypointer
+
+    def forward(self, x, boxes):
+        mask_prob = x
+
+        scores = None
+        if self.keypointer:
+            mask_prob, scores = self.keypointer(x, boxes)
+
+        assert len(boxes) == 1, "Only non-batched inference supported for now"
+        boxes_per_image = [box.bbox.size(0) for box in boxes]
+        mask_prob = mask_prob.split(boxes_per_image, dim=0)
+        scores = scores.split(boxes_per_image, dim=0)
+
+        results = []
+        for prob, box, score in zip(mask_prob, boxes, scores):
+            bbox = BoxList(box.bbox, box.size, mode="xyxy")
+            for field in box.fields():
+                bbox.add_field(field, box.get_field(field))
+            prob = PersonKeypoints(prob, box.size)
+            prob.add_field("logits", score)
+            bbox.add_field("keypoints", prob)
+            results.append(bbox)
+
+        return results
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = np.maximum(widths, 1)
+    heights = np.maximum(heights, 1)
+    widths_ceil = np.ceil(widths)
+    heights_ceil = np.ceil(heights)
+
+    # NCHW to NHWC for use with OpenCV
+    maps = np.transpose(maps, [0, 2, 3, 1])
+    min_size = 0  # cfg.KRCNN.INFERENCE_MIN_SIZE
+    num_keypoints = maps.shape[3]
+    xy_preds = np.zeros((len(rois), 3, num_keypoints), dtype=np.float32)
+    end_scores = np.zeros((len(rois), num_keypoints), dtype=np.float32)
+    for i in range(len(rois)):
+        if min_size > 0:
+            roi_map_width = int(np.maximum(widths_ceil[i], min_size))
+            roi_map_height = int(np.maximum(heights_ceil[i], min_size))
+        else:
+            roi_map_width = widths_ceil[i]
+            roi_map_height = heights_ceil[i]
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = cv2.resize(
+            maps[i], (roi_map_width, roi_map_height), interpolation=cv2.INTER_CUBIC
+        )
+        # Bring back to CHW
+        roi_map = np.transpose(roi_map, [2, 0, 1])
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(axis=1)
+        x_int = pos % w
+        y_int = (pos - x_int) // w
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int + 0.5) * width_correction
+        y = (y_int + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[np.arange(num_keypoints), y_int, x_int]
+
+    return np.transpose(xy_preds, [0, 2, 1]), end_scores
+
+
+class Keypointer(object):
+    """
+    Projects a set of masks in an image on the locations
+    specified by the bounding boxes
+    """
+
+    def __init__(self, padding=0):
+        self.padding = padding
+
+    def __call__(self, masks, boxes):
+        # TODO do this properly
+        if isinstance(boxes, BoxList):
+            boxes = [boxes]
+        assert len(boxes) == 1
+
+        result, scores = heatmaps_to_keypoints(
+            masks.detach().cpu().numpy(), boxes[0].bbox.cpu().numpy()
+        )
+        return torch.from_numpy(result).to(masks.device), torch.as_tensor(scores, device=masks.device)
+
+
+def make_roi_keypoint_post_processor(cfg):
+    keypointer = Keypointer()
+    keypoint_post_processor = KeypointPostProcessor(keypointer)
+    return keypoint_post_processor
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1414782ab2a42bd1161c8496c434406df12619d6
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py
@@ -0,0 +1,50 @@
+import torch
+
+from .roi_keypoint_feature_extractors import make_roi_keypoint_feature_extractor
+from .roi_keypoint_predictors import make_roi_keypoint_predictor
+from .inference import make_roi_keypoint_post_processor
+from .loss import make_roi_keypoint_loss_evaluator
+
+
+class ROIKeypointHead(torch.nn.Module):
+    def __init__(self, cfg):
+        super(ROIKeypointHead, self).__init__()
+        self.cfg = cfg.clone()
+        self.feature_extractor = make_roi_keypoint_feature_extractor(cfg)
+        self.predictor = make_roi_keypoint_predictor(cfg)
+        self.post_processor = make_roi_keypoint_post_processor(cfg)
+        self.loss_evaluator = make_roi_keypoint_loss_evaluator(cfg)
+
+    def forward(self, features, proposals, targets=None):
+        """
+        Arguments:
+            features (list[Tensor]): feature-maps from possibly several levels
+            proposals (list[BoxList]): proposal boxes
+            targets (list[BoxList], optional): the ground-truth targets.
+
+        Returns:
+            x (Tensor): the result of the feature extractor
+            proposals (list[BoxList]): during training, the original proposals
+                are returned. During testing, the predicted boxlists are returned
+                with the `mask` field set
+            losses (dict[Tensor]): During training, returns the losses for the
+                head. During testing, returns an empty dict.
+        """
+        if self.training:
+            with torch.no_grad():
+                proposals = self.loss_evaluator.subsample(proposals, targets)
+
+        x = self.feature_extractor(features, proposals)
+        kp_logits = self.predictor(x)
+
+        if not self.training:
+            result = self.post_processor(kp_logits, proposals)
+            return x, result, {}
+
+        loss_kp = self.loss_evaluator(proposals, kp_logits)
+
+        return x, proposals, dict(loss_kp=loss_kp)
+
+
+def build_roi_keypoint_head(cfg):
+    return ROIKeypointHead(cfg)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..53716c80281f6e8e767552f061d91d486027831e
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/loss.py
@@ -0,0 +1,183 @@
+import torch
+from torch.nn import functional as F
+
+from maskrcnn_benchmark.modeling.matcher import Matcher
+
+from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import (
+    BalancedPositiveNegativeSampler,
+)
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+from maskrcnn_benchmark.modeling.utils import cat
+from maskrcnn_benchmark.layers import smooth_l1_loss
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+
+from maskrcnn_benchmark.structures.keypoint import keypoints_to_heat_map
+
+
+def project_keypoints_to_heatmap(keypoints, proposals, discretization_size):
+    proposals = proposals.convert("xyxy")
+    return keypoints_to_heat_map(
+        keypoints.keypoints, proposals.bbox, discretization_size
+    )
+
+
+def cat_boxlist_with_keypoints(boxlists):
+    assert all(boxlist.has_field("keypoints") for boxlist in boxlists)
+
+    kp = [boxlist.get_field("keypoints").keypoints for boxlist in boxlists]
+    kp = cat(kp, 0)
+
+    fields = boxlists[0].get_fields()
+    fields = [field for field in fields if field != "keypoints"]
+
+    boxlists = [boxlist.copy_with_fields(fields) for boxlist in boxlists]
+    boxlists = cat_boxlist(boxlists)
+    boxlists.add_field("keypoints", kp)
+    return boxlists
+
+
+def _within_box(points, boxes):
+    """Validate which keypoints are contained inside a given box.
+    points: NxKx2
+    boxes: Nx4
+    output: NxK
+    """
+    x_within = (points[..., 0] >= boxes[:, 0, None]) & (
+        points[..., 0] <= boxes[:, 2, None]
+    )
+    y_within = (points[..., 1] >= boxes[:, 1, None]) & (
+        points[..., 1] <= boxes[:, 3, None]
+    )
+    return x_within & y_within
+
+
+class KeypointRCNNLossComputation(object):
+    def __init__(self, proposal_matcher, fg_bg_sampler, discretization_size):
+        """
+        Arguments:
+            proposal_matcher (Matcher)
+            fg_bg_sampler (BalancedPositiveNegativeSampler)
+            discretization_size (int)
+        """
+        self.proposal_matcher = proposal_matcher
+        self.fg_bg_sampler = fg_bg_sampler
+        self.discretization_size = discretization_size
+
+    def match_targets_to_proposals(self, proposal, target):
+        match_quality_matrix = boxlist_iou(target, proposal)
+        matched_idxs = self.proposal_matcher(match_quality_matrix)
+        # Keypoint RCNN needs "labels" and "keypoints "fields for creating the targets
+        target = target.copy_with_fields(["labels", "keypoints"])
+        # get the targets corresponding GT for each proposal
+        # NB: need to clamp the indices because we can have a single
+        # GT in the image, and matched_idxs can be -2, which goes
+        # out of bounds
+        matched_targets = target[matched_idxs.clamp(min=0)]
+        matched_targets.add_field("matched_idxs", matched_idxs)
+        return matched_targets
+
+    def prepare_targets(self, proposals, targets):
+        labels = []
+        keypoints = []
+        for proposals_per_image, targets_per_image in zip(proposals, targets):
+            matched_targets = self.match_targets_to_proposals(
+                proposals_per_image, targets_per_image
+            )
+            matched_idxs = matched_targets.get_field("matched_idxs")
+
+            labels_per_image = matched_targets.get_field("labels")
+            labels_per_image = labels_per_image.to(dtype=torch.int64)
+
+            # this can probably be removed, but is left here for clarity
+            # and completeness
+            # TODO check if this is the right one, as BELOW_THRESHOLD
+            neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
+            labels_per_image[neg_inds] = 0
+
+            keypoints_per_image = matched_targets.get_field("keypoints")
+            within_box = _within_box(
+                keypoints_per_image.keypoints, matched_targets.bbox
+            )
+            vis_kp = keypoints_per_image.keypoints[..., 2] > 0
+            is_visible = (within_box & vis_kp).sum(1) > 0
+
+            labels_per_image[~is_visible] = -1
+
+            labels.append(labels_per_image)
+            keypoints.append(keypoints_per_image)
+
+        return labels, keypoints
+
+    def subsample(self, proposals, targets):
+        """
+        This method performs the positive/negative sampling, and return
+        the sampled proposals.
+        Note: this function keeps a state.
+
+        Arguments:
+            proposals (list[BoxList])
+            targets (list[BoxList])
+        """
+
+        labels, keypoints = self.prepare_targets(proposals, targets)
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+
+        proposals = list(proposals)
+        # add corresponding label and regression_targets information to the bounding boxes
+        for labels_per_image, keypoints_per_image, proposals_per_image in zip(
+            labels, keypoints, proposals
+        ):
+            proposals_per_image.add_field("labels", labels_per_image)
+            proposals_per_image.add_field("keypoints", keypoints_per_image)
+
+        # distributed sampled proposals, that were obtained on all feature maps
+        # concatenated via the fg_bg_sampler, into individual feature map levels
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
+            zip(sampled_pos_inds, sampled_neg_inds)
+        ):
+            img_sampled_inds = torch.nonzero(pos_inds_img).squeeze(1)
+            proposals_per_image = proposals[img_idx][img_sampled_inds]
+            proposals[img_idx] = proposals_per_image
+
+        self._proposals = proposals
+        return proposals
+
+    def __call__(self, proposals, keypoint_logits):
+        heatmaps = []
+        valid = []
+        for proposals_per_image in proposals:
+            kp = proposals_per_image.get_field("keypoints")
+            heatmaps_per_image, valid_per_image = project_keypoints_to_heatmap(
+                kp, proposals_per_image, self.discretization_size
+            )
+            heatmaps.append(heatmaps_per_image.view(-1))
+            valid.append(valid_per_image.view(-1))
+
+        keypoint_targets = cat(heatmaps, dim=0)
+        valid = cat(valid, dim=0).to(dtype=torch.bool)
+        valid = torch.nonzero(valid).squeeze(1)
+
+        # torch.mean (in binary_cross_entropy_with_logits) does'nt
+        # accept empty tensors, so handle it sepaartely
+        if keypoint_targets.numel() == 0 or len(valid) == 0:
+            return keypoint_logits.sum() * 0
+
+        N, K, H, W = keypoint_logits.shape
+        keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+        keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+        return keypoint_loss
+
+
+def make_roi_keypoint_loss_evaluator(cfg):
+    matcher = Matcher(
+        cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD,
+        cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD,
+        allow_low_quality_matches=False,
+    )
+    fg_bg_sampler = BalancedPositiveNegativeSampler(
+        cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
+    )
+    resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.RESOLUTION
+    loss_evaluator = KeypointRCNNLossComputation(matcher, fg_bg_sampler, resolution)
+    return loss_evaluator
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0b4b90be3efebf777871399b7dca821fec60a45
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py
@@ -0,0 +1,96 @@
+from torch import nn
+from torch.nn import functional as F
+
+from maskrcnn_benchmark.modeling.poolers import Pooler
+
+from maskrcnn_benchmark.layers import Conv2d
+from maskrcnn_benchmark.layers import ConvTranspose2d
+
+
+class KeypointRCNNFeatureExtractor(nn.Module):
+    def __init__(self, cfg):
+        super(KeypointRCNNFeatureExtractor, self).__init__()
+
+        resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+        self.pooler = pooler
+
+        input_features = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        layers = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS
+        next_feature = input_features
+        self.blocks = []
+        for layer_idx, layer_features in enumerate(layers, 1):
+            layer_name = "conv_fcn{}".format(layer_idx)
+            module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1)
+            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+            nn.init.constant_(module.bias, 0)
+            self.add_module(layer_name, module)
+            next_feature = layer_features
+            self.blocks.append(layer_name)
+
+    def forward(self, x, proposals):
+        x = self.pooler(x, proposals)
+        for layer_name in self.blocks:
+            x = F.relu(getattr(self, layer_name)(x))
+        return x
+
+class KeypointRCNNFeature2XZoomExtractor(nn.Module):
+    def __init__(self, cfg):
+        super(KeypointRCNNFeature2XZoomExtractor, self).__init__()
+
+        resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+        self.pooler = pooler
+
+        input_features = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        layers = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS
+        next_feature = input_features
+        self.blocks = []
+        for layer_idx, layer_features in enumerate(layers, 1):
+            layer_name = "conv_fcn{}".format(layer_idx)
+            module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1)
+            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+            nn.init.constant_(module.bias, 0)
+            self.add_module(layer_name, module)
+            if layer_idx==len(layers)//2:
+                deconv_kernel = 4
+                kps_upsacle = ConvTranspose2d(layer_features, layer_features, deconv_kernel,
+                                              stride=2, padding=deconv_kernel//2-1)
+                nn.init.kaiming_normal_(kps_upsacle.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(kps_upsacle.bias, 0)
+                self.add_module("conv_fcn_upscale", kps_upsacle)
+                self.blocks.append("conv_fcn_upscale")
+
+            next_feature = layer_features
+            self.blocks.append(layer_name)
+
+    def forward(self, x, proposals):
+        x = self.pooler(x, proposals)
+        for layer_name in self.blocks:
+            x = F.relu(getattr(self, layer_name)(x))
+        return x
+
+
+_ROI_KEYPOINT_FEATURE_EXTRACTORS = {
+    "KeypointRCNNFeatureExtractor": KeypointRCNNFeatureExtractor,
+    "KeypointRCNNFeature2XZoomExtractor": KeypointRCNNFeature2XZoomExtractor
+}
+
+
+def make_roi_keypoint_feature_extractor(cfg):
+    func = _ROI_KEYPOINT_FEATURE_EXTRACTORS[
+        cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR
+    ]
+    return func(cfg)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff8ec3849695737580d5b2da2b411c489216a1a
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
@@ -0,0 +1,39 @@
+from torch import nn
+from torch.nn import functional as F
+
+from maskrcnn_benchmark import layers
+
+
+class KeypointRCNNPredictor(nn.Module):
+    def __init__(self, cfg):
+        super(KeypointRCNNPredictor, self).__init__()
+        input_features = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS[-1]
+        num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES
+        deconv_kernel = 4
+        self.kps_score_lowres = layers.ConvTranspose2d(
+            input_features,
+            num_keypoints,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(
+            self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
+        )
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+
+    def forward(self, x):
+        x = self.kps_score_lowres(x)
+        x = layers.interpolate(
+            x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
+        )
+        return x
+
+
+_ROI_KEYPOINT_PREDICTOR = {"KeypointRCNNPredictor": KeypointRCNNPredictor}
+
+
+def make_roi_keypoint_predictor(cfg):
+    func = _ROI_KEYPOINT_PREDICTOR[cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR]
+    return func(cfg)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/hourglass.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..82e81b6697536ff23f8b88f7ea1d89da8d8c28e1
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/hourglass.py
@@ -0,0 +1,65 @@
+from torch import nn
+
+from maskrcnn_benchmark.modeling.make_layers import make_conv3x3
+
+
+class Residual(nn.Module):
+    def __init__(self, inp_dim, out_dim, use_gn=False):
+        super(Residual, self).__init__()
+        self.relu = nn.ReLU()
+        # self.bn1 = nn.BatchNorm2d(inp_dim)
+        self.conv1 = make_conv3x3(inp_dim, int(out_dim / 2), 1, use_relu=False, use_gn=use_gn)
+        # self.bn2 = nn.BatchNorm2d(int(out_dim / 2))
+        self.conv2 = make_conv3x3(int(out_dim / 2), int(out_dim / 2), 3, use_relu=False, use_gn=use_gn)
+        # self.bn3 = nn.BatchNorm2d(int(out_dim / 2))
+        self.conv3 = make_conv3x3(int(out_dim / 2), out_dim, 1, use_relu=False, use_gn=use_gn)
+        if inp_dim == out_dim:
+            self.need_skip = False
+        else:
+            self.need_skip = True
+            self.skip_layer = make_conv3x3(inp_dim, out_dim, 1, use_relu=False, use_gn=False)
+
+    def forward(self, x):
+        if self.need_skip:
+            residual = self.skip_layer(x)
+        else:
+            residual = x
+        out = x
+        # out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv1(out)
+        # out = self.bn2(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        # out = self.bn3(out)
+        out = self.relu(out)
+        out = self.conv3(out)
+        out += residual
+        return out
+
+
+class Hourglass(nn.Module):
+    def __init__(self, n, f, gn=False, increase=0):
+        super(Hourglass, self).__init__()
+        nf = f + increase
+        self.up1 = Residual(f, f)
+        # Lower branch
+        self.pool1 = nn.MaxPool2d(2, 2)
+        self.low1 = Residual(f, nf)
+        self.n = n
+        # Recursive hourglass
+        if self.n > 1:
+            self.low2 = Hourglass(n-1, nf, gn=gn)
+        else:
+            self.low2 = Residual(nf, nf, gn)
+        self.low3 = Residual(nf, f, gn)
+        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
+
+    def forward(self, x):
+        up1 = self.up1(x)
+        pool1 = self.pool1(x)
+        low1 = self.low1(pool1)
+        low2 = self.low2(low1)
+        low3 = self.low3(low2)
+        up2 = self.up2(low3)
+        return up1 + up2
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c02b78101a6d33db1917377080265206548c7cf
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
@@ -0,0 +1,224 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+
+def convert_mask_grounding_to_od_logits(logits, positive_map_label_to_token, num_classes):
+    od_logits = torch.zeros(logits.shape[0], num_classes + 1, logits.shape[2], logits.shape[3]).to(logits.device)
+    for label_j in positive_map_label_to_token:
+        od_logits[:, label_j, :, :] = logits[:, torch.LongTensor(positive_map_label_to_token[label_j]), :, :].mean(1)
+    mask_prob = od_logits.sigmoid()
+    return mask_prob
+
+
+# TODO check if want to return a single BoxList or a composite
+# object
+class MaskPostProcessor(nn.Module):
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    If a masker object is passed, it will additionally
+    project the masks in the image according to the locations in boxes,
+    """
+
+    def __init__(self, masker=None, mdetr_style_aggregate_class_num=None, vl_version=None):
+        super(MaskPostProcessor, self).__init__()
+        self.masker = masker
+        self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num
+        self.vl_version = vl_version
+
+    def forward(self, x, boxes, positive_map_label_to_token=None):
+        """
+        Arguments:
+            x (Tensor): the mask logits
+            boxes (list[BoxList]): bounding boxes that are used as
+                reference, one for ech image
+
+        Returns:
+            results (list[BoxList]): one BoxList for each image, containing
+                the extra field mask
+        """
+        if self.vl_version:
+            mask_prob = convert_mask_grounding_to_od_logits(x, positive_map_label_to_token, self.mdetr_style_aggregate_class_num)
+        else:
+            mask_prob = x.sigmoid()
+
+        # select masks coresponding to the predicted classes
+        num_masks = x.shape[0]
+        labels = [bbox.get_field("labels") for bbox in boxes]
+        labels = torch.cat(labels)
+        if not self.vl_version:
+            # TODO: a hack for binary mask head
+            labels = (labels > 0).to(dtype=torch.int64)
+
+        index = torch.arange(num_masks, device=labels.device)
+        mask_prob = mask_prob[index, labels][:, None]
+
+        boxes_per_image = [len(box) for box in boxes]
+        mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+        if self.masker:
+            mask_prob = self.masker(mask_prob, boxes)
+
+        results = []
+        for prob, box in zip(mask_prob, boxes):
+            bbox = BoxList(box.bbox, box.size, mode="xyxy")
+            for field in box.fields():
+                bbox.add_field(field, box.get_field(field))
+            bbox.add_field("mask", prob)
+            results.append(bbox)
+
+        return results
+
+
+class MaskPostProcessorCOCOFormat(MaskPostProcessor):
+    """
+    From the results of the CNN, post process the results
+    so that the masks are pasted in the image, and
+    additionally convert the results to COCO format.
+    """
+
+    def forward(self, x, boxes, positive_map_label_to_token=None, vl_version=None):
+        import pycocotools.mask as mask_util
+        import numpy as np
+
+        results = super(MaskPostProcessorCOCOFormat, self).forward(x, boxes)
+        for result in results:
+            masks = result.get_field("mask").cpu()
+            rles = [
+                mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
+                for mask in masks
+            ]
+            for rle in rles:
+                rle["counts"] = rle["counts"].decode("utf-8")
+            result.add_field("mask", rles)
+        return results
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily gor paste_mask_in_image
+def expand_boxes(boxes, scale):
+    w_half = (boxes[:, 2] - boxes[:, 0]) * .5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * .5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * .5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+def expand_masks(mask, padding):
+    N = mask.shape[0]
+    M = mask.shape[-1]
+    pad2 = 2 * padding
+    scale = float(M + pad2) / M
+    padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
+    padded_mask[:, :, padding:-padding, padding:-padding] = mask
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
+    padded_mask, scale = expand_masks(mask[None], padding=padding)
+    mask = padded_mask[0, 0]
+    box = expand_boxes(box[None], scale)[0]
+    box = box.to(dtype=torch.int32)
+
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = mask.to(torch.float32)
+    mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
+    mask = mask[0][0]
+
+    if thresh >= 0:
+        mask = mask > thresh
+    else:
+        # for visualization and debugging, we also
+        # allow it to return an unmodified mask
+        mask = (mask * 255).to(torch.bool)
+
+    im_mask = torch.zeros((im_h, im_w), dtype=torch.bool)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[
+        (y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])
+    ]
+    return im_mask
+
+
+class Masker(object):
+    """
+    Projects a set of masks in an image on the locations
+    specified by the bounding boxes
+    """
+
+    def __init__(self, threshold=0.5, padding=1):
+        self.threshold = threshold
+        self.padding = padding
+
+    def forward_single_image(self, masks, boxes):
+        boxes = boxes.convert("xyxy")
+        im_w, im_h = boxes.size
+        res = [
+            paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
+            for mask, box in zip(masks, boxes.bbox)
+        ]
+        if len(res) > 0:
+            res = torch.stack(res, dim=0)[:, None]
+        else:
+            res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
+        return res
+
+    def __call__(self, masks, boxes):
+        if isinstance(boxes, BoxList):
+            boxes = [boxes]
+
+        # Make some sanity check
+        assert len(boxes) == len(masks), "Masks and boxes should have the same length."
+
+        # TODO:  Is this JIT compatible?
+        # If not we should make it compatible.
+        results = []
+        for mask, box in zip(masks, boxes):
+            assert mask.shape[0] == len(box), "Number of objects should be the same."
+            result = self.forward_single_image(mask, box)
+            results.append(result)
+        return results
+
+
+def make_roi_mask_post_processor(cfg):
+    if cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS:
+        mask_threshold = cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD
+        masker = Masker(threshold=mask_threshold, padding=1)
+    else:
+        masker = None
+    mdetr_style_aggregate_class_num = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM
+    mask_post_processor = MaskPostProcessor(masker,
+                                            mdetr_style_aggregate_class_num,
+                                            vl_version=cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"))
+    return mask_post_processor
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..22edb57fd8b67e370f819ba5d8a2a37df8c2a6f7
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py
@@ -0,0 +1,179 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch.nn import functional as F
+
+from maskrcnn_benchmark.layers import smooth_l1_loss
+from maskrcnn_benchmark.modeling.matcher import Matcher
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+from maskrcnn_benchmark.modeling.utils import cat
+
+
+def project_masks_on_boxes(segmentation_masks, proposals, discretization_size):
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+
+    Arguments:
+        segmentation_masks: an instance of SegmentationMask
+        proposals: an instance of BoxList
+    """
+    masks = []
+    M = discretization_size
+    device = proposals.bbox.device
+    proposals = proposals.convert("xyxy")
+    assert segmentation_masks.size == proposals.size, "{}, {}".format(
+        segmentation_masks, proposals
+    )
+    # TODO put the proposals on the CPU, as the representation for the
+    # masks is not efficient GPU-wise (possibly several small tensors for
+    # representing a single instance mask)
+    proposals = proposals.bbox.to(torch.device("cpu"))
+    for segmentation_mask, proposal in zip(segmentation_masks, proposals):
+        # crop the masks, resize them to the desired resolution and
+        # then convert them to the tensor representation,
+        # instead of the list representation that was used
+        cropped_mask = segmentation_mask.crop(proposal)
+        scaled_mask = cropped_mask.resize((M, M))
+        mask = scaled_mask.convert(mode="mask")
+        masks.append(mask)
+    if len(masks) == 0:
+        return torch.empty(0, dtype=torch.float32, device=device)
+    return torch.stack(masks, dim=0).to(device, dtype=torch.float32)
+
+
+class MaskRCNNLossComputation(object):
+    def __init__(self, proposal_matcher, discretization_size, vl_version=False):
+        """
+        Arguments:
+            proposal_matcher (Matcher)
+            discretization_size (int)
+        """
+        self.proposal_matcher = proposal_matcher
+        self.discretization_size = discretization_size
+        self.vl_version = vl_version
+
+    def match_targets_to_proposals(self, proposal, target):
+        match_quality_matrix = boxlist_iou(target, proposal)
+        matched_idxs = self.proposal_matcher(match_quality_matrix)
+        # Mask RCNN needs "labels" and "masks "fields for creating the targets
+        if self.vl_version:
+            target = target.copy_with_fields(["positive_map", "masks"])
+        else:
+            target = target.copy_with_fields(["labels", "masks"])
+        # get the targets corresponding GT for each proposal
+        # NB: need to clamp the indices because we can have a single
+        # GT in the image, and matched_idxs can be -2, which goes
+        # out of bounds
+        matched_targets = target[matched_idxs.clamp(min=0)]
+        matched_targets.add_field("matched_idxs", matched_idxs)
+        return matched_targets
+
+    def prepare_targets(self, proposals, targets):
+        labels = []
+        masks = []
+        positive_maps = []
+        for proposals_per_image, targets_per_image in zip(proposals, targets):
+            matched_targets = self.match_targets_to_proposals(
+                proposals_per_image, targets_per_image
+            )
+            matched_idxs = matched_targets.get_field("matched_idxs")
+
+            if self.vl_version:
+                positive_maps_per_image = matched_targets.get_field("positive_map")
+
+                # this can probably be removed, but is left here for clarity
+                # and completeness
+                neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
+                positive_maps_per_image[neg_inds, :] = 0
+
+                positive_maps.append(positive_maps_per_image)
+
+                # TODO: make sure for the softmax [NoObj] case
+                labels_per_image = positive_maps_per_image.sum(dim=-1)
+                labels_per_image = labels_per_image.to(dtype=torch.int64)
+            else:
+                labels_per_image = matched_targets.get_field("labels")
+                labels_per_image = labels_per_image.to(dtype=torch.int64)
+
+                # this can probably be removed, but is left here for clarity
+                # and completeness
+                neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
+                labels_per_image[neg_inds] = 0
+
+            # mask scores are only computed on positive samples
+            positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1)
+
+            segmentation_masks = matched_targets.get_field("masks")
+            segmentation_masks = segmentation_masks[positive_inds]
+
+            positive_proposals = proposals_per_image[positive_inds]
+
+            masks_per_image = project_masks_on_boxes(
+                segmentation_masks, positive_proposals, self.discretization_size
+            )
+
+            labels.append(labels_per_image)
+            masks.append(masks_per_image)
+
+        return labels, masks, positive_maps
+
+    def __call__(self, proposals, mask_logits, targets):
+        """
+        Arguments:
+            proposals (list[BoxList])
+            mask_logits (Tensor)
+            targets (list[BoxList])
+
+        Return:
+            mask_loss (Tensor): scalar tensor containing the loss
+        """
+        labels, mask_targets, positive_maps = self.prepare_targets(proposals, targets)
+
+        labels = cat(labels, dim=0)
+        mask_targets = cat(mask_targets, dim=0)
+
+        positive_inds = torch.nonzero(labels > 0).squeeze(1)
+        labels_pos = labels[positive_inds]
+        # TODO: a hack for binary mask head
+        labels_pos = (labels_pos > 0).to(dtype=torch.int64)
+
+        # torch.mean (in binary_cross_entropy_with_logits) doesn't
+        # accept empty tensors, so handle it separately
+        if mask_targets.numel() == 0:
+            return mask_logits.sum() * 0
+
+        if self.vl_version:
+            positive_maps = cat(positive_maps, dim=0)
+            mask_logits_pos = []
+            for positive_ind in positive_inds:
+                positive_map = positive_maps[positive_ind]
+                # TODO: make sure for the softmax [NoObj] case
+                mask_logit_pos = mask_logits[positive_ind][torch.nonzero(positive_map).squeeze(1)].mean(dim=0, keepdim=True)
+                mask_logits_pos.append(mask_logit_pos)
+            mask_logits_pos = cat(mask_logits_pos, dim=0)
+            mask_loss = F.binary_cross_entropy_with_logits(
+                mask_logits_pos, mask_targets
+            )
+        else:
+            mask_loss = F.binary_cross_entropy_with_logits(
+                mask_logits[positive_inds, labels_pos], mask_targets
+            )
+        return mask_loss
+
+
+def make_roi_mask_loss_evaluator(cfg):
+    matcher = Matcher(
+        cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD,
+        cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD,
+        allow_low_quality_matches=False,
+    )
+
+    loss_evaluator = MaskRCNNLossComputation(
+        matcher, cfg.MODEL.ROI_MASK_HEAD.RESOLUTION,
+        vl_version=cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL")
+    )
+
+    return loss_evaluator
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff8fb640becf3a455c1eaaf72ad72b5c079a5fe
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
@@ -0,0 +1,88 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch import nn
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+from .roi_mask_feature_extractors import make_roi_mask_feature_extractor
+from .roi_mask_predictors import make_roi_mask_predictor
+from .inference import make_roi_mask_post_processor
+from .loss import make_roi_mask_loss_evaluator
+
+
+def keep_only_positive_boxes(boxes):
+    """
+    Given a set of BoxList containing the `labels` field,
+    return a set of BoxList for which `labels > 0`.
+
+    Arguments:
+        boxes (list of BoxList)
+    """
+    assert isinstance(boxes, (list, tuple))
+    assert isinstance(boxes[0], BoxList)
+    assert boxes[0].has_field("labels")
+    positive_boxes = []
+    positive_inds = []
+    num_boxes = 0
+    for boxes_per_image in boxes:
+        labels = boxes_per_image.get_field("labels")
+        inds_mask = labels > 0
+        inds = inds_mask.nonzero().squeeze(1)
+        positive_boxes.append(boxes_per_image[inds])
+        positive_inds.append(inds_mask)
+    return positive_boxes, positive_inds
+
+
+class ROIMaskHead(torch.nn.Module):
+    def __init__(self, cfg):
+        super(ROIMaskHead, self).__init__()
+        self.cfg = cfg.clone()
+        self.feature_extractor = make_roi_mask_feature_extractor(cfg)
+        self.predictor = make_roi_mask_predictor(cfg)
+        self.post_processor = make_roi_mask_post_processor(cfg)
+        self.loss_evaluator = make_roi_mask_loss_evaluator(cfg)
+
+    def forward(self, features, proposals, targets=None,
+                language_dict_features=None,
+                positive_map_label_to_token=None
+                ):
+        """
+        Arguments:
+            features (list[Tensor]): feature-maps from possibly several levels
+            proposals (list[BoxList]): proposal boxes
+            targets (list[BoxList], optional): the ground-truth targets.
+            language_dict_features: language features: hidden, embedding, mask, ...
+
+        Returns:
+            x (Tensor): the result of the feature extractor
+            proposals (list[BoxList]): during training, the original proposals
+                are returned. During testing, the predicted boxlists are returned
+                with the `mask` field set
+            losses (dict[Tensor]): During training, returns the losses for the
+                head. During testing, returns an empty dict.
+        """
+        if self.training:
+            # during training, only focus on positive boxes
+            all_proposals = proposals
+            proposals, positive_inds = keep_only_positive_boxes(proposals)
+        if self.training and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
+            x = features
+            x = x[torch.cat(positive_inds, dim=0)]
+        else:
+            x = self.feature_extractor(features, proposals)
+        if self.cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"):
+            mask_logits = self.predictor(x, language_dict_features)
+        else:
+            mask_logits = self.predictor(x)
+
+        if not self.training:
+            result = self.post_processor(mask_logits, proposals, positive_map_label_to_token)
+            return x, result, {}
+
+        loss_mask = self.loss_evaluator(proposals, mask_logits, targets)
+
+        return x, all_proposals, dict(loss_mask=loss_mask)
+
+
+def build_roi_mask_head(cfg):
+    return ROIMaskHead(cfg)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
new file mode 100644
index 0000000000000000000000000000000000000000..c891feb22703e5d47be37ec20189c4c2bbd7c14c
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
@@ -0,0 +1,117 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from torch import nn
+from torch.nn import functional as F
+
+from .hourglass import Hourglass
+from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor
+from maskrcnn_benchmark.modeling.poolers import Pooler
+from maskrcnn_benchmark.layers import Conv2d
+from maskrcnn_benchmark.modeling.make_layers import make_conv3x3
+
+
+
+class MaskRCNNFPNFeatureExtractor(nn.Module):
+    """
+    Heads for FPN for classification
+    """
+
+    def __init__(self, cfg):
+        """
+        Arguments:
+            num_classes (int): number of output classes
+            input_size (int): number of channels of the input once it's flattened
+            representation_size (int): size of the intermediate representation
+        """
+        super(MaskRCNNFPNFeatureExtractor, self).__init__()
+
+        resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+        input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        self.pooler = pooler
+
+        use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN
+        layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS
+        dilation = cfg.MODEL.ROI_MASK_HEAD.DILATION
+
+        next_feature = input_size
+        self.blocks = []
+        for layer_idx, layer_features in enumerate(layers, 1):
+            layer_name = "mask_fcn{}".format(layer_idx)
+            module = make_conv3x3(next_feature, layer_features, 
+                dilation=dilation, stride=1, use_gn=use_gn
+            )
+            self.add_module(layer_name, module)
+            next_feature = layer_features
+            self.blocks.append(layer_name)
+
+    def forward(self, x, proposals):
+        x = self.pooler(x, proposals)
+
+        for layer_name in self.blocks:
+            x = F.relu(getattr(self, layer_name)(x))
+
+        return x
+
+
+class HourglassFPNFeatureExtractor(nn.Module):
+    """
+    Heads for FPN for classification
+    """
+
+    def __init__(self, cfg):
+        """
+        Arguments:
+            num_classes (int): number of output classes
+            input_size (int): number of channels of the input once it's flattened
+            representation_size (int): size of the intermediate representation
+        """
+        super(HourglassFPNFeatureExtractor, self).__init__()
+
+        resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
+        pooler = Pooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio,
+        )
+        input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        self.pooler = pooler
+
+        use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN
+        layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS
+        scale = cfg.MODEL.ROI_MASK_HEAD.HG_SCALE
+
+        assert input_size==layers[0]
+        self.blocks = []
+        for layer_idx, layer_features in enumerate(layers, 1):
+            layer_name = "mask_hg{}".format(layer_idx)
+            module = Hourglass(scale, layer_features, gn=use_gn)
+            self.add_module(layer_name, module)
+            self.blocks.append(layer_name)
+
+    def forward(self, x, proposals):
+        x = self.pooler(x, proposals)
+
+        for layer_name in self.blocks:
+            x = F.relu(getattr(self, layer_name)(x))
+
+        return x
+
+
+_ROI_MASK_FEATURE_EXTRACTORS = {
+    "ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor,
+    "MaskRCNNFPNFeatureExtractor": MaskRCNNFPNFeatureExtractor,
+    "HourglassFPNFeatureExtractor": HourglassFPNFeatureExtractor,
+}
+
+
+def make_roi_mask_feature_extractor(cfg):
+    func = _ROI_MASK_FEATURE_EXTRACTORS[cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR]
+    return func(cfg)
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c7e7ff2fb28e3ee39750cbbec39a46539f0c455
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from maskrcnn_benchmark.layers import Conv2d, _NewEmptyTensorOp
+from maskrcnn_benchmark.layers import ConvTranspose2d
+from ...utils import permute_and_flatten
+
+
+class MaskRCNNC4Predictor(nn.Module):
+    def __init__(self, cfg):
+        super(MaskRCNNC4Predictor, self).__init__()
+        # TODO: a hack for binary mask head
+        # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
+        num_classes = 2
+        dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]
+
+        if cfg.MODEL.ROI_HEADS.USE_FPN:
+            num_inputs = dim_reduced
+        else:
+            stage_index = 4
+            stage2_relative_factor = 2 ** (stage_index - 1)
+            res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
+            num_inputs = res2_out_channels * stage2_relative_factor
+
+        self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
+        self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0)
+
+        for name, param in self.named_parameters():
+            if "bias" in name:
+                nn.init.constant_(param, 0)
+            elif "weight" in name:
+                # Caffe2 implementation uses MSRAFill, which in fact
+                # corresponds to kaiming_normal_ in PyTorch
+                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
+
+    def forward(self, x):
+        x = F.relu(self.conv5_mask(x))
+        return self.mask_fcn_logits(x)
+
+
+class VLMaskRCNNC4Predictor(nn.Module):
+    def __init__(self, cfg):
+        super(VLMaskRCNNC4Predictor, self).__init__()
+        dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]
+
+        if cfg.MODEL.ROI_HEADS.USE_FPN:
+            num_inputs = dim_reduced
+        else:
+            stage_index = 4
+            stage2_relative_factor = 2 ** (stage_index - 1)
+            res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
+            num_inputs = res2_out_channels * stage2_relative_factor
+
+        self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
+
+        # self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0)
+        log_scale = cfg.MODEL.DYHEAD.LOG_SCALE
+        self.out_dim = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
+        self.dot_product_projection_image = nn.Identity()
+        self.dot_product_projection_text = nn.Linear(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM,
+                                                     dim_reduced, bias=True)
+        self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
+        self.bias_lang = nn.Parameter(torch.zeros(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM), requires_grad=True)
+
+        for name, param in self.named_parameters():
+            if "bias" in name:
+                nn.init.constant_(param, 0)
+            elif "weight" in name:
+                # Caffe2 implementation uses MSRAFill, which in fact
+                # corresponds to kaiming_normal_ in PyTorch
+                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
+
+    def forward(self, x, language_dict_features):
+        x = F.relu(self.conv5_mask(x))
+        if x.numel() <= 0:
+            output_shape = [x.shape[0], self.out_dim] + x.shape[-2:]
+            return _NewEmptyTensorOp.apply(x, output_shape)
+
+        embedding = language_dict_features["hidden"]
+        # norm
+        embedding = F.normalize(embedding, p=2, dim=-1)
+        dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0)
+        dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang)
+
+        B, C, H, W = x.shape
+        # add bias (language)
+        dot_product_proj_queries = self.dot_product_projection_image(x)
+        dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W)
+        A = dot_product_proj_queries.shape[1]
+        bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(1, A, 1)
+
+        # dot product
+        dot_product_logit = (torch.matmul(dot_product_proj_queries,
+                                          dot_product_proj_tokens.transpose(-1,
+                                                                            -2)) / self.log_scale.exp()) + bias
+        # clamp for stability
+        dot_product_logit = torch.clamp(dot_product_logit, max=50000)
+        dot_product_logit = torch.clamp(dot_product_logit, min=-50000)
+        dot_product_logit = dot_product_logit.view(B, H, W, self.out_dim).permute(0, 3, 1, 2)
+        return dot_product_logit
+
+
+_ROI_MASK_PREDICTOR = {"MaskRCNNC4Predictor": MaskRCNNC4Predictor,
+                       "VLMaskRCNNC4Predictor": VLMaskRCNNC4Predictor}
+
+
+def make_roi_mask_predictor(cfg):
+    func = _ROI_MASK_PREDICTOR[cfg.MODEL.ROI_MASK_HEAD.PREDICTOR]
+    return func(cfg)
diff --git a/maskrcnn_benchmark/modeling/rpn/__init__.py b/maskrcnn_benchmark/modeling/rpn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..49e6ed6932c3011f357a1c9a97fce632ae9e6eb3
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# from .rpn import build_rpn
+from .rpn import RPNModule
+from .retina import RetinaNetModule
+from .fcos import FCOSModule
+from .atss import ATSSModule
+from .dyhead import DyHeadModule
+from .vldyhead import VLDyHeadModule
+
+_RPN_META_ARCHITECTURES = {"RPN": RPNModule,
+                           "RETINA": RetinaNetModule,
+                           "FCOS": FCOSModule,
+                           "ATSS": ATSSModule,
+                           "DYHEAD": DyHeadModule,
+                           "VLDYHEAD": VLDyHeadModule
+                           }
+
+
+def build_rpn(cfg):
+    """
+    This gives the gist of it. Not super important because it doesn't change as much
+    """
+    rpn_arch = _RPN_META_ARCHITECTURES[cfg.MODEL.RPN_ARCHITECTURE]
+    return rpn_arch(cfg)
diff --git a/maskrcnn_benchmark/modeling/rpn/anchor_generator.py b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c396730280e0f8f549872ad0403de36a2a626321
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py
@@ -0,0 +1,425 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.image_list import ImageList
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+
+class BufferList(nn.Module):
+    """
+    Similar to nn.ParameterList, but for buffers
+    """
+
+    def __init__(self, buffers=None):
+        super(BufferList, self).__init__()
+        if buffers is not None:
+            self.extend(buffers)
+
+    def extend(self, buffers):
+        offset = len(self)
+        for i, buffer in enumerate(buffers):
+            self.register_buffer(str(offset + i), buffer)
+        return self
+
+    def __len__(self):
+        return len(self._buffers)
+
+    def __iter__(self):
+        return iter(self._buffers.values())
+
+
+class AnchorGenerator(nn.Module):
+    """
+    For a set of image sizes and feature maps, computes a set
+    of anchors
+    """
+
+    def __init__(
+        self,
+        sizes=(128, 256, 512),
+        aspect_ratios=(0.5, 1.0, 2.0),
+        anchor_strides=(8, 16, 32),
+        straddle_thresh=0,
+    ):
+        super(AnchorGenerator, self).__init__()
+
+        if len(anchor_strides) == 1:
+            anchor_stride = anchor_strides[0]
+            cell_anchors = [
+                generate_anchors(anchor_stride, sizes, aspect_ratios).float()
+            ]
+        else:
+            if len(anchor_strides) != len(sizes):
+                raise RuntimeError("FPN should have #anchor_strides == #sizes")
+            cell_anchors = [
+                generate_anchors(
+                    anchor_stride,
+                    size if isinstance(size, (tuple, list)) else (size,),
+                    aspect_ratios
+                ).float()
+                for anchor_stride, size in zip(anchor_strides, sizes)
+            ]
+        self.strides = anchor_strides
+        self.cell_anchors = BufferList(cell_anchors)
+        self.straddle_thresh = straddle_thresh
+
+    def num_anchors_per_location(self):
+        return [len(cell_anchors) for cell_anchors in self.cell_anchors]
+
+    def grid_anchors(self, grid_sizes):
+        anchors = []
+        for size, stride, base_anchors in zip(
+            grid_sizes, self.strides, self.cell_anchors
+        ):
+            grid_height, grid_width = size
+            device = base_anchors.device
+            shifts_x = torch.arange(
+                0, grid_width * stride, step=stride, dtype=torch.float32, device=device
+            )
+            shifts_y = torch.arange(
+                0, grid_height * stride, step=stride, dtype=torch.float32, device=device
+            )
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+            shift_x = shift_x.reshape(-1)
+            shift_y = shift_y.reshape(-1)
+            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
+
+            anchors.append(
+                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
+            )
+
+        return anchors
+
+    def add_visibility_to(self, boxlist):
+        image_width, image_height = boxlist.size
+        anchors = boxlist.bbox
+        if self.straddle_thresh >= 0:
+            inds_inside = (
+                (anchors[..., 0] >= -self.straddle_thresh)
+                & (anchors[..., 1] >= -self.straddle_thresh)
+                & (anchors[..., 2] < image_width + self.straddle_thresh)
+                & (anchors[..., 3] < image_height + self.straddle_thresh)
+            )
+        else:
+            device = anchors.device
+            inds_inside = torch.ones(anchors.shape[0], dtype=torch.bool, device=device)
+        boxlist.add_field("visibility", inds_inside)
+
+    def forward(self, image_list, feature_maps):
+        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
+        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
+        anchors = []
+        if isinstance(image_list, ImageList):
+            for i, (image_height, image_width) in enumerate(image_list.image_sizes):
+                anchors_in_image = []
+                for anchors_per_feature_map in anchors_over_all_feature_maps:
+                    boxlist = BoxList(
+                        anchors_per_feature_map, (image_width, image_height), mode="xyxy"
+                    )
+                    self.add_visibility_to(boxlist)
+                    anchors_in_image.append(boxlist)
+                anchors.append(anchors_in_image)
+        else:
+            image_height, image_width = [int(x) for x in image_list.size()[-2:]]
+            anchors_in_image = []
+            for anchors_per_feature_map in anchors_over_all_feature_maps:
+                boxlist = BoxList(
+                    anchors_per_feature_map, (image_width, image_height), mode="xyxy"
+                )
+                self.add_visibility_to(boxlist)
+                anchors_in_image.append(boxlist)
+            anchors.append(anchors_in_image)
+        return anchors
+
+
+def make_anchor_generator(config):
+    anchor_sizes = config.MODEL.RPN.ANCHOR_SIZES
+    aspect_ratios = config.MODEL.RPN.ASPECT_RATIOS
+    anchor_stride = config.MODEL.RPN.ANCHOR_STRIDE
+    straddle_thresh = config.MODEL.RPN.STRADDLE_THRESH
+
+    if config.MODEL.RPN.USE_FPN:
+        assert len(anchor_stride) == len(
+            anchor_sizes
+        ), "FPN should have len(ANCHOR_STRIDE) == len(ANCHOR_SIZES)"
+    else:
+        assert len(anchor_stride) == 1, "Non-FPN should have a single ANCHOR_STRIDE"
+    anchor_generator = AnchorGenerator(
+        anchor_sizes, aspect_ratios, anchor_stride, straddle_thresh
+    )
+    return anchor_generator
+
+
+def make_anchor_generator_complex(config):
+    anchor_sizes = config.MODEL.RPN.ANCHOR_SIZES
+    aspect_ratios = config.MODEL.RPN.ASPECT_RATIOS
+    anchor_strides = config.MODEL.RPN.ANCHOR_STRIDE
+    straddle_thresh = config.MODEL.RPN.STRADDLE_THRESH
+    octave = config.MODEL.RPN.OCTAVE
+    scales_per_octave = config.MODEL.RPN.SCALES_PER_OCTAVE
+
+    if config.MODEL.RPN.USE_FPN:
+        assert len(anchor_strides) == len(anchor_sizes), "Only support FPN now"
+        new_anchor_sizes = []
+        for size in anchor_sizes:
+            per_layer_anchor_sizes = []
+            for scale_per_octave in range(scales_per_octave):
+                octave_scale = octave ** (scale_per_octave / float(scales_per_octave))
+                per_layer_anchor_sizes.append(octave_scale * size)
+            new_anchor_sizes.append(tuple(per_layer_anchor_sizes))
+    else:
+        assert len(anchor_strides) == 1, "Non-FPN should have a single ANCHOR_STRIDE"
+        new_anchor_sizes = anchor_sizes
+
+    anchor_generator = AnchorGenerator(
+        tuple(new_anchor_sizes), aspect_ratios, anchor_strides, straddle_thresh
+    )
+    return anchor_generator
+
+
+class CenterAnchorGenerator(nn.Module):
+    """
+    For a set of image sizes and feature maps, computes a set
+    of anchors
+    """
+
+    def __init__(
+            self,
+            sizes=(128, 256, 512),
+            aspect_ratios=(0.5, 1.0, 2.0),
+            anchor_strides=(8, 16, 32),
+            straddle_thresh=0,
+            anchor_shift=(0.0, 0.0, 0.0, 0.0),
+            use_relative=False
+    ):
+        super(CenterAnchorGenerator, self).__init__()
+
+        self.sizes = sizes
+        self.aspect_ratios = aspect_ratios
+        self.strides = anchor_strides
+        self.straddle_thresh = straddle_thresh
+        self.anchor_shift = anchor_shift
+        self.use_relative = use_relative
+
+    def add_visibility_to(self, boxlist):
+        image_width, image_height = boxlist.size
+        anchors = boxlist.bbox
+        if self.straddle_thresh >= 0:
+            inds_inside = (
+                    (anchors[..., 0] >= -self.straddle_thresh)
+                    & (anchors[..., 1] >= -self.straddle_thresh)
+                    & (anchors[..., 2] < image_width + self.straddle_thresh)
+                    & (anchors[..., 3] < image_height + self.straddle_thresh)
+            )
+        else:
+            device = anchors.device
+            inds_inside = torch.ones(anchors.shape[0], dtype=torch.uint8, device=device)
+        boxlist.add_field("visibility", inds_inside)
+
+    def forward(self, centers, image_sizes, feature_maps):
+        shift_left, shift_top, shift_right, shift_down = self.anchor_shift
+        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
+        anchors = []
+        for i, ((image_height, image_width), center_bbox) in enumerate(zip(image_sizes, centers)):
+            center = center_bbox.get_field("centers")
+            boxlist_per_level = []
+            for size, fsize in zip(self.sizes, grid_sizes):
+                for ratios in self.aspect_ratios:
+
+                    size_ratios = size*size / ratios
+                    ws = np.round(np.sqrt(size_ratios))
+                    hs = np.round(ws * ratios)
+
+                    anchors_per_level = torch.cat(
+                        (
+                            center[:,0,None] - 0.5 * (1 + shift_left) * (ws - 1),
+                            center[:,1,None] - 0.5 * (1 + shift_top) * (hs - 1),
+                            center[:,0,None] + 0.5 * (1 + shift_right) * (ws - 1),
+                            center[:,1,None] + 0.5 * (1 + shift_down) * (hs - 1),
+                        ),
+                        dim=1
+                    )
+                    boxlist = BoxList(anchors_per_level, (image_width, image_height), mode="xyxy")
+                    boxlist.add_field('cbox', center_bbox)
+                    self.add_visibility_to(boxlist)
+                    boxlist_per_level.append(boxlist)
+            if self.use_relative:
+                area = center_bbox.area()
+                for ratios in self.aspect_ratios:
+
+                    size_ratios = area / ratios
+                    ws = torch.round(torch.sqrt(size_ratios))
+                    hs = torch.round(ws * ratios)
+
+                    anchors_per_level = torch.stack(
+                        (
+                            center[:,0] - (1 + shift_left) * ws,
+                            center[:,1] - (1 + shift_top) * hs,
+                            center[:,0] + (1 + shift_right) * ws,
+                            center[:,1] + (1 + shift_down) * hs,
+                        ),
+                        dim=1
+                    )
+                    boxlist = BoxList(anchors_per_level, (image_width, image_height), mode="xyxy")
+                    boxlist.add_field('cbox', center_bbox)
+                    self.add_visibility_to(boxlist)
+                    boxlist_per_level.append(boxlist)
+            anchors_in_image = cat_boxlist(boxlist_per_level)
+            anchors.append(anchors_in_image)
+        return anchors
+
+
+def make_center_anchor_generator(config):
+    anchor_sizes = config.MODEL.RPN.ANCHOR_SIZES
+    aspect_ratios = config.MODEL.RPN.ASPECT_RATIOS
+    anchor_strides = config.MODEL.RPN.ANCHOR_STRIDE
+    straddle_thresh = config.MODEL.RPN.STRADDLE_THRESH
+    octave = config.MODEL.RPN.OCTAVE
+    scales_per_octave = config.MODEL.RPN.SCALES_PER_OCTAVE
+    anchor_shift = config.MODEL.RPN.ANCHOR_SHIFT
+    use_relative = config.MODEL.RPN.USE_RELATIVE_SIZE
+
+    if config.MODEL.RPN.USE_FPN:
+        assert len(anchor_strides) == len(anchor_sizes), "Only support FPN now"
+        new_anchor_sizes = []
+        for size in anchor_sizes:
+            per_layer_anchor_sizes = []
+            for scale_per_octave in range(scales_per_octave):
+                octave_scale = octave ** (scale_per_octave / float(scales_per_octave))
+                per_layer_anchor_sizes.append(octave_scale * size)
+            new_anchor_sizes.append(tuple(per_layer_anchor_sizes))
+    else:
+        assert len(anchor_strides) == 1, "Non-FPN should have a single ANCHOR_STRIDE"
+        new_anchor_sizes = anchor_sizes
+
+    anchor_generator = CenterAnchorGenerator(
+        tuple(new_anchor_sizes), aspect_ratios, anchor_strides, straddle_thresh, anchor_shift, use_relative
+    )
+    return anchor_generator
+
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+#
+# Based on:
+# --------------------------------------------------------
+# Faster R-CNN
+# Copyright (c) 2015 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ross Girshick and Sean Bell
+# --------------------------------------------------------
+
+
+# Verify that we compute the same anchors as Shaoqing's matlab implementation:
+#
+#    >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
+#    >> anchors
+#
+#    anchors =
+#
+#       -83   -39   100    56
+#      -175   -87   192   104
+#      -359  -183   376   200
+#       -55   -55    72    72
+#      -119  -119   136   136
+#      -247  -247   264   264
+#       -35   -79    52    96
+#       -79  -167    96   184
+#      -167  -343   184   360
+
+# array([[ -83.,  -39.,  100.,   56.],
+#        [-175.,  -87.,  192.,  104.],
+#        [-359., -183.,  376.,  200.],
+#        [ -55.,  -55.,   72.,   72.],
+#        [-119., -119.,  136.,  136.],
+#        [-247., -247.,  264.,  264.],
+#        [ -35.,  -79.,   52.,   96.],
+#        [ -79., -167.,   96.,  184.],
+#        [-167., -343.,  184.,  360.]])
+
+
+def generate_anchors(
+    stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
+):
+    """Generates a matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
+    are centered on stride / 2, have (approximate) sqrt areas of the specified
+    sizes, and aspect ratios as given.
+    """
+    return _generate_anchors(
+        stride,
+        np.array(sizes, dtype=np.float) / stride,
+        np.array(aspect_ratios, dtype=np.float),
+    )
+
+
+def _generate_anchors(base_size, scales, aspect_ratios):
+    """Generate anchor (reference) windows by enumerating aspect ratios X
+    scales wrt a reference (0, 0, base_size - 1, base_size - 1) window.
+    """
+    anchor = np.array([1, 1, base_size, base_size], dtype=np.float) - 1
+    anchors = _ratio_enum(anchor, aspect_ratios)
+    anchors = np.vstack(
+        [_scale_enum(anchors[i, :], scales) for i in range(anchors.shape[0])]
+    )
+    return torch.from_numpy(anchors)
+
+
+def _whctrs(anchor):
+    """Return width, height, x center, and y center for an anchor (window)."""
+    w = anchor[2] - anchor[0] + 1
+    h = anchor[3] - anchor[1] + 1
+    x_ctr = anchor[0] + 0.5 * (w - 1)
+    y_ctr = anchor[1] + 0.5 * (h - 1)
+    return w, h, x_ctr, y_ctr
+
+
+def _mkanchors(ws, hs, x_ctr, y_ctr):
+    """Given a vector of widths (ws) and heights (hs) around a center
+    (x_ctr, y_ctr), output a set of anchors (windows).
+    """
+    ws = ws[:, np.newaxis]
+    hs = hs[:, np.newaxis]
+    anchors = np.hstack(
+        (
+            x_ctr - 0.5 * (ws - 1),
+            y_ctr - 0.5 * (hs - 1),
+            x_ctr + 0.5 * (ws - 1),
+            y_ctr + 0.5 * (hs - 1),
+        )
+    )
+    return anchors
+
+
+def _ratio_enum(anchor, ratios):
+    """Enumerate a set of anchors for each aspect ratio wrt an anchor."""
+    w, h, x_ctr, y_ctr = _whctrs(anchor)
+    size = w * h
+    size_ratios = size / ratios
+    ws = np.round(np.sqrt(size_ratios))
+    hs = np.round(ws * ratios)
+    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
+    return anchors
+
+
+def _scale_enum(anchor, scales):
+    """Enumerate a set of anchors for each scale wrt an anchor."""
+    w, h, x_ctr, y_ctr = _whctrs(anchor)
+    ws = w * scales
+    hs = h * scales
+    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
+    return anchors
diff --git a/maskrcnn_benchmark/modeling/rpn/atss.py b/maskrcnn_benchmark/modeling/rpn/atss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1132522ad4491477c0dd320d43b78351daf69b4
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/atss.py
@@ -0,0 +1,233 @@
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .inference import make_atss_postprocessor
+from .loss import make_atss_loss_evaluator
+
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.layers import Scale, DFConv2d, DYReLU, SELayer
+from .anchor_generator import make_anchor_generator_complex
+
+
+class BoxCoder(object):
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def encode(self, gt_boxes, anchors):
+
+        TO_REMOVE = 1  # TODO remove
+        ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
+        ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
+        ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
+        ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
+
+        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE
+        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE
+        gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2
+        gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2
+
+        wx, wy, ww, wh = (10., 10., 5., 5.)
+        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+        targets_dw = ww * torch.log(gt_widths / ex_widths)
+        targets_dh = wh * torch.log(gt_heights / ex_heights)
+        targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+
+        return targets
+
+    def decode(self, preds, anchors):
+
+        anchors = anchors.to(preds.dtype)
+
+        TO_REMOVE = 1  # TODO remove
+        widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
+        heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
+        ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
+        ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
+
+        wx, wy, ww, wh = (10., 10., 5., 5.)
+        dx = preds[:, 0::4] / wx
+        dy = preds[:, 1::4] / wy
+        dw = preds[:, 2::4] / ww
+        dh = preds[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=math.log(1000. / 16))
+        dh = torch.clamp(dh, max=math.log(1000. / 16))
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        pred_boxes = torch.zeros_like(preds)
+        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1)
+        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1)
+        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1)
+        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1)
+
+        return pred_boxes
+
+
+class ATSSHead(torch.nn.Module):
+    def __init__(self, cfg):
+        super(ATSSHead, self).__init__()
+        self.cfg = cfg
+        num_classes = cfg.MODEL.ATSS.NUM_CLASSES - 1
+        num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
+        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        channels = cfg.MODEL.ATSS.CHANNELS
+        use_gn = cfg.MODEL.ATSS.USE_GN
+        use_bn = cfg.MODEL.ATSS.USE_BN
+        use_dcn_in_tower = cfg.MODEL.ATSS.USE_DFCONV
+        use_dyrelu = cfg.MODEL.ATSS.USE_DYRELU
+        use_se = cfg.MODEL.ATSS.USE_SE
+
+        cls_tower = []
+        bbox_tower = []
+        for i in range(cfg.MODEL.ATSS.NUM_CONVS):
+            if use_dcn_in_tower and \
+                    i == cfg.MODEL.ATSS.NUM_CONVS - 1:
+                conv_func = DFConv2d
+            else:
+                conv_func = nn.Conv2d
+
+            cls_tower.append(
+                conv_func(
+                    in_channels if i==0 else channels,
+                    channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=True
+                )
+            )
+            if use_gn:
+                cls_tower.append(nn.GroupNorm(32, channels))
+            if use_bn:
+                cls_tower.append(nn.BatchNorm2d(channels))
+            if use_se:
+                cls_tower.append(SELayer(channels))
+            if use_dyrelu:
+                cls_tower.append(DYReLU(channels, channels))
+            else:
+                cls_tower.append(nn.ReLU())
+
+            bbox_tower.append(
+                conv_func(
+                    in_channels if i == 0 else channels,
+                    channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=True
+                )
+            )
+            if use_gn:
+                bbox_tower.append(nn.GroupNorm(32, channels))
+            if use_bn:
+                bbox_tower.append(nn.BatchNorm2d(channels))
+            if use_se:
+                bbox_tower.append(SELayer(channels))
+            if use_dyrelu:
+                bbox_tower.append(DYReLU(channels, channels))
+            else:
+                bbox_tower.append(nn.ReLU())
+
+        self.add_module('cls_tower', nn.Sequential(*cls_tower))
+        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
+        self.cls_logits = nn.Conv2d(
+            channels, num_anchors * num_classes, kernel_size=3, stride=1,
+            padding=1
+        )
+        self.bbox_pred = nn.Conv2d(
+            channels, num_anchors * 4, kernel_size=3, stride=1,
+            padding=1
+        )
+        self.centerness = nn.Conv2d(
+            channels, num_anchors * 1, kernel_size=3, stride=1,
+            padding=1
+        )
+
+        # initialization
+        for modules in [self.cls_tower, self.bbox_tower,
+                        self.cls_logits, self.bbox_pred,
+                        self.centerness]:
+            for l in modules.modules():
+                if isinstance(l, nn.Conv2d):
+                    torch.nn.init.normal_(l.weight, std=0.01)
+                    torch.nn.init.constant_(l.bias, 0)
+
+        # initialize the bias for focal loss
+        prior_prob = cfg.MODEL.ATSS.PRIOR_PROB
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        torch.nn.init.constant_(self.cls_logits.bias, bias_value)
+
+        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
+
+    def forward(self, x):
+        logits = []
+        bbox_reg = []
+        centerness = []
+        for l, feature in enumerate(x):
+            cls_tower = self.cls_tower(feature)
+            box_tower = self.bbox_tower(feature)
+
+            logits.append(self.cls_logits(cls_tower))
+
+            bbox_pred = self.scales[l](self.bbox_pred(box_tower))
+            bbox_reg.append(bbox_pred)
+
+            centerness.append(self.centerness(box_tower))
+        return logits, bbox_reg, centerness
+
+
+class ATSSModule(torch.nn.Module):
+
+    def __init__(self, cfg):
+        super(ATSSModule, self).__init__()
+        self.cfg = cfg
+        self.head = ATSSHead(cfg)
+        box_coder = BoxCoder(cfg)
+        self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder)
+        self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True)
+        self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False)
+        self.anchor_generator = make_anchor_generator_complex(cfg)
+
+    def forward(self, images, features, targets=None):
+        box_cls, box_regression, centerness = self.head(features)
+        anchors = self.anchor_generator(images, features)
+ 
+        if self.training:
+            return self._forward_train(box_cls, box_regression, centerness, targets, anchors)
+        else:
+            return self._forward_test(box_cls, box_regression, centerness, anchors)
+
+    def _forward_train(self, box_cls, box_regression, centerness, targets, anchors):
+        loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator(
+            box_cls, box_regression, centerness, targets, anchors
+        )
+        losses = {
+            "loss_cls": loss_box_cls,
+            "loss_reg": loss_box_reg,
+            "loss_centerness": loss_centerness
+        }
+        if self.cfg.MODEL.RPN_ONLY:
+            return None, losses
+        else:
+            boxes = self.box_selector_train(box_cls, box_regression, centerness, anchors)
+            train_boxes = []
+            for b, a in zip(boxes, anchors):
+                a = cat_boxlist(a)
+                b.add_field("visibility", torch.ones(b.bbox.shape[0], dtype=torch.bool, device=b.bbox.device))
+                del b.extra_fields['scores']
+                del b.extra_fields['labels']
+                train_boxes.append(cat_boxlist([b, a]))
+            return train_boxes, losses
+
+    def _forward_test(self, box_cls, box_regression, centerness, anchors):
+        boxes = self.box_selector_test(box_cls, box_regression, centerness, anchors)
+        return boxes, {}
diff --git a/maskrcnn_benchmark/modeling/rpn/dyhead.py b/maskrcnn_benchmark/modeling/rpn/dyhead.py
new file mode 100644
index 0000000000000000000000000000000000000000..e84cd3a9d28ed337bece0c87689301b865642324
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/dyhead.py
@@ -0,0 +1,377 @@
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .inference import make_atss_postprocessor
+from .loss import make_atss_loss_evaluator
+from .anchor_generator import make_anchor_generator_complex
+
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.layers import Scale, DYReLU, SELayer, ModulatedDeformConv
+from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d
+from maskrcnn_benchmark.modeling.backbone.fbnet import *
+
+
+class h_sigmoid(nn.Module):
+    def __init__(self, inplace=True, h_max=1):
+        super(h_sigmoid, self).__init__()
+        self.relu = nn.ReLU6(inplace=inplace)
+        self.h_max = h_max
+
+    def forward(self, x):
+        return self.relu(x + 3) * self.h_max / 6
+
+
+class BoxCoder(object):
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def encode(self, gt_boxes, anchors):
+        TO_REMOVE = 1  # TODO remove
+        ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
+        ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
+        ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
+        ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
+
+        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE
+        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE
+        gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2
+        gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2
+
+        wx, wy, ww, wh = (10., 10., 5., 5.)
+        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+        targets_dw = ww * torch.log(gt_widths / ex_widths)
+        targets_dh = wh * torch.log(gt_heights / ex_heights)
+        targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+
+        return targets
+
+    def decode(self, preds, anchors):
+        anchors = anchors.to(preds.dtype)
+
+        TO_REMOVE = 1  # TODO remove
+        widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
+        heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
+        ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
+        ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
+
+        wx, wy, ww, wh = (10., 10., 5., 5.)
+        dx = preds[:, 0::4] / wx
+        dy = preds[:, 1::4] / wy
+        dw = preds[:, 2::4] / ww
+        dh = preds[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=math.log(1000. / 16))
+        dh = torch.clamp(dh, max=math.log(1000. / 16))
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        pred_boxes = torch.zeros_like(preds)
+        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1)
+        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1)
+        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1)
+        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1)
+
+        return pred_boxes
+
+
+class Conv3x3Norm(torch.nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 groups=1,
+                 deformable=False,
+                 bn_type=None):
+        super(Conv3x3Norm, self).__init__()
+
+        if deformable:
+            self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1,
+                                            groups=groups)
+        else:
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups)
+
+        if isinstance(bn_type, (list, tuple)):
+            assert len(bn_type) == 2
+            assert bn_type[0] == "gn"
+            gn_group = bn_type[1]
+            bn_type = bn_type[0]
+
+        if bn_type == "bn":
+            bn_op = nn.BatchNorm2d(out_channels)
+        elif bn_type == "sbn":
+            bn_op = nn.SyncBatchNorm(out_channels)
+        elif bn_type == "nsbn":
+            bn_op = NaiveSyncBatchNorm2d(out_channels)
+        elif bn_type == "gn":
+            bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=out_channels)
+        elif bn_type == "af":
+            bn_op = FrozenBatchNorm2d(out_channels)
+        if bn_type is not None:
+            self.bn = bn_op
+        else:
+            self.bn = None
+
+    def forward(self, input, **kwargs):
+        x = self.conv(input, **kwargs)
+        if self.bn:
+            x = self.bn(x)
+        return x
+
+
+class DyConv(torch.nn.Module):
+    def __init__(self,
+                 in_channels=256,
+                 out_channels=256,
+                 conv_func=nn.Conv2d,
+                 use_dyfuse=True,
+                 use_dyrelu=False,
+                 use_deform=False
+                 ):
+        super(DyConv, self).__init__()
+
+        self.DyConv = nn.ModuleList()
+        self.DyConv.append(conv_func(in_channels, out_channels, 1))
+        self.DyConv.append(conv_func(in_channels, out_channels, 1))
+        self.DyConv.append(conv_func(in_channels, out_channels, 2))
+
+        if use_dyfuse:
+            self.AttnConv = nn.Sequential(
+                nn.AdaptiveAvgPool2d(1),
+                nn.Conv2d(in_channels, 1, kernel_size=1),
+                nn.ReLU(inplace=True))
+            self.h_sigmoid = h_sigmoid()
+        else:
+            self.AttnConv = None
+
+        if use_dyrelu:
+            self.relu = DYReLU(in_channels, out_channels)
+        else:
+            self.relu = nn.ReLU()
+
+        if use_deform:
+            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
+        else:
+            self.offset = None
+
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.DyConv.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight.data, 0, 0.01)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+        if self.AttnConv is not None:
+            for m in self.AttnConv.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.normal_(m.weight.data, 0, 0.01)
+                    if m.bias is not None:
+                        m.bias.data.zero_()
+
+    def forward(self, x):
+        next_x = []
+        for level, feature in enumerate(x):
+
+            conv_args = dict()
+            if self.offset is not None:
+                offset_mask = self.offset(feature)
+                offset = offset_mask[:, :18, :, :]
+                mask = offset_mask[:, 18:, :, :].sigmoid()
+                conv_args = dict(offset=offset, mask=mask)
+
+            temp_fea = [self.DyConv[1](feature, **conv_args)]
+
+            if level > 0:
+                temp_fea.append(self.DyConv[2](x[level - 1], **conv_args))
+            if level < len(x) - 1:
+                temp_fea.append(F.upsample_bilinear(self.DyConv[0](x[level + 1], **conv_args),
+                                                    size=[feature.size(2), feature.size(3)]))
+            mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False)
+
+            if self.AttnConv is not None:
+                attn_fea = []
+                res_fea = []
+                for fea in temp_fea:
+                    res_fea.append(fea)
+                    attn_fea.append(self.AttnConv(fea))
+
+                res_fea = torch.stack(res_fea)
+                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea))
+
+                mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False)
+
+            next_x.append(mean_fea)
+
+        next_x = [self.relu(item) for item in next_x]
+        return next_x
+
+
+class DyHead(torch.nn.Module):
+    def __init__(self, cfg):
+        super(DyHead, self).__init__()
+        self.cfg = cfg
+        num_classes = cfg.MODEL.DYHEAD.NUM_CLASSES - 1
+        num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
+        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        channels = cfg.MODEL.DYHEAD.CHANNELS
+        if cfg.MODEL.DYHEAD.USE_GN:
+            bn_type = ['gn', cfg.MODEL.GROUP_NORM.NUM_GROUPS]
+        elif cfg.MODEL.DYHEAD.USE_NSYNCBN:
+            bn_type = 'nsbn'
+        elif cfg.MODEL.DYHEAD.USE_SYNCBN:
+            bn_type = 'sbn'
+        else:
+            bn_type = None
+
+        use_dyrelu = cfg.MODEL.DYHEAD.USE_DYRELU
+        use_dyfuse = cfg.MODEL.DYHEAD.USE_DYFUSE
+        use_deform = cfg.MODEL.DYHEAD.USE_DFCONV
+
+        if cfg.MODEL.DYHEAD.CONV_FUNC:
+            conv_func = lambda i, o, s: eval(cfg.MODEL.DYHEAD.CONV_FUNC)(i, o, s, bn_type=bn_type)
+        else:
+            conv_func = lambda i, o, s: Conv3x3Norm(i, o, s, deformable=use_deform, bn_type=bn_type)
+
+        dyhead_tower = []
+        for i in range(cfg.MODEL.DYHEAD.NUM_CONVS):
+            dyhead_tower.append(
+                DyConv(
+                    in_channels if i == 0 else channels,
+                    channels,
+                    conv_func=conv_func,
+                    use_dyrelu=(use_dyrelu and in_channels == channels) if i == 0 else use_dyrelu,
+                    use_dyfuse=(use_dyfuse and in_channels == channels) if i == 0 else use_dyfuse,
+                    use_deform=(use_deform and in_channels == channels) if i == 0 else use_deform,
+                )
+            )
+
+        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))
+        if cfg.MODEL.DYHEAD.COSINE_SCALE <= 0:
+            self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1)
+            self.cls_logits_bias = None
+        else:
+            self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1, bias=False)
+            self.cls_logits_bias = nn.Parameter(torch.zeros(num_anchors * num_classes, requires_grad=True))
+            self.cosine_scale = nn.Parameter(torch.ones(1) * cfg.MODEL.DYHEAD.COSINE_SCALE)
+        self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=1)
+        self.centerness = nn.Conv2d(channels, num_anchors * 1, kernel_size=1)
+
+        # initialization
+        for modules in [self.cls_logits, self.bbox_pred,
+                        self.centerness]:
+            for l in modules.modules():
+                if isinstance(l, nn.Conv2d):
+                    torch.nn.init.normal_(l.weight, std=0.01)
+                    if hasattr(l, 'bias') and l.bias is not None:
+                        torch.nn.init.constant_(l.bias, 0)
+
+        # initialize the bias for focal loss
+        prior_prob = cfg.MODEL.DYHEAD.PRIOR_PROB
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        if self.cls_logits_bias is None:
+            torch.nn.init.constant_(self.cls_logits.bias, bias_value)
+        else:
+            torch.nn.init.constant_(self.cls_logits_bias, bias_value)
+
+        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
+
+    def extract_feature(self, x):
+        output = []
+        for i in range(len(self.dyhead_tower)):
+            x = self.dyhead_tower[i](x)
+            output.append(x)
+        return output
+
+    def forward(self, x):
+        logits = []
+        bbox_reg = []
+        centerness = []
+
+        dyhead_tower = self.dyhead_tower(x)
+
+        for l, feature in enumerate(x):
+            if self.cls_logits_bias is None:
+                logit = self.cls_logits(dyhead_tower[l])
+            else:
+                # CosineSimOutputLayers: https://github.com/ucbdrive/few-shot-object-detection/blob/master/fsdet/modeling/roi_heads/fast_rcnn.py#L448-L464
+                # normalize the input x along the `channel` dimension
+                x_norm = torch.norm(dyhead_tower[l], p=2, dim=1, keepdim=True).expand_as(dyhead_tower[l])
+                x_normalized = dyhead_tower[l].div(x_norm + 1e-5)
+                # normalize weight
+                temp_norm = (
+                    torch.norm(self.cls_logits.weight.data, p=2, dim=1, keepdim=True)
+                        .expand_as(self.cls_logits.weight.data)
+                )
+                self.cls_logits.weight.data = self.cls_logits.weight.data.div(
+                    temp_norm + 1e-5
+                )
+                cos_dist = self.cls_logits(x_normalized)
+                logit = self.cosine_scale * cos_dist + self.cls_logits_bias.reshape(1, len(self.cls_logits_bias), 1, 1)
+            logits.append(logit)
+
+            bbox_pred = self.scales[l](self.bbox_pred(dyhead_tower[l]))
+            bbox_reg.append(bbox_pred)
+
+            centerness.append(self.centerness(dyhead_tower[l]))
+        return logits, bbox_reg, centerness
+
+
+class DyHeadModule(torch.nn.Module):
+
+    def __init__(self, cfg):
+        super(DyHeadModule, self).__init__()
+        self.cfg = cfg
+        self.head = DyHead(cfg)
+        box_coder = BoxCoder(cfg)
+        self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder)
+        self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True)
+        self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False)
+        self.anchor_generator = make_anchor_generator_complex(cfg)
+
+    def forward(self, images, features, targets=None):
+        box_cls, box_regression, centerness = self.head(features)
+        anchors = self.anchor_generator(images, features)
+
+        if self.training:
+            return self._forward_train(box_cls, box_regression, centerness, targets, anchors)
+        else:
+            return self._forward_test(box_cls, box_regression, centerness, anchors)
+
+    def _forward_train(self, box_cls, box_regression, centerness, targets, anchors):
+        loss_box_cls, loss_box_reg, loss_centerness, _, _, _, _ = self.loss_evaluator(
+            box_cls, box_regression, centerness, targets, anchors
+        )
+        losses = {
+            "loss_cls": loss_box_cls,
+            "loss_reg": loss_box_reg,
+            "loss_centerness": loss_centerness
+        }
+        if self.cfg.MODEL.RPN_ONLY:
+            return None, losses
+        else:
+            # boxes = self.box_selector_train(box_cls, box_regression, centerness, anchors)
+            boxes = self.box_selector_train(box_regression, centerness, anchors, box_cls)
+            train_boxes = []
+            # for b, a in zip(boxes, anchors):
+            #     a = cat_boxlist(a)
+            #     b.add_field("visibility", torch.ones(b.bbox.shape[0], dtype=torch.bool, device=b.bbox.device))
+            #     del b.extra_fields['scores']
+            #     del b.extra_fields['labels']
+            #     train_boxes.append(cat_boxlist([b, a]))
+            for b, t in zip(boxes, targets):
+                tb = t.copy_with_fields(["labels"])
+                tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device))
+                train_boxes.append(cat_boxlist([b, tb]))
+            return train_boxes, losses
+
+    def _forward_test(self, box_cls, box_regression, centerness, anchors):
+        boxes = self.box_selector_test(box_regression, centerness, anchors, box_cls)
+        return boxes, {}
diff --git a/maskrcnn_benchmark/modeling/rpn/fcos.py b/maskrcnn_benchmark/modeling/rpn/fcos.py
new file mode 100644
index 0000000000000000000000000000000000000000..c69dab0fd86d7b891ee001228368294a0fd56ae4
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/fcos.py
@@ -0,0 +1,236 @@
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from maskrcnn_benchmark.modeling import registry
+from maskrcnn_benchmark.layers import Scale, DFConv2d
+from .loss import make_fcos_loss_evaluator
+from .anchor_generator import make_center_anchor_generator
+from .inference import make_fcos_postprocessor
+
+
+@registry.RPN_HEADS.register("FCOSHead")
+class FCOSHead(torch.nn.Module):
+    def __init__(self, cfg):
+
+        super(FCOSHead, self).__init__()
+        # TODO: Implement the sigmoid version first.
+        num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1
+        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        use_gn = cfg.MODEL.FCOS.USE_GN
+        use_bn = cfg.MODEL.FCOS.USE_BN
+        use_dcn_in_tower = cfg.MODEL.FCOS.USE_DFCONV
+        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
+        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
+        self.centerness_on_reg = cfg.MODEL.FCOS.CENTERNESS_ON_REG
+
+        cls_tower = []
+        bbox_tower = []
+        for i in range(cfg.MODEL.FCOS.NUM_CONVS):
+            if use_dcn_in_tower and \
+                    i == cfg.MODEL.FCOS.NUM_CONVS - 1:
+                conv_func = DFConv2d
+            else:
+                conv_func = nn.Conv2d
+
+            cls_tower.append(
+                conv_func(
+                    in_channels,
+                    in_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=True
+                )
+            )
+            if use_gn:
+                cls_tower.append(nn.GroupNorm(32, in_channels))
+            if use_bn:
+                cls_tower.append(nn.BatchNorm2d(in_channels))
+            cls_tower.append(nn.ReLU())
+
+            bbox_tower.append(
+                conv_func(
+                    in_channels,
+                    in_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=True
+                )
+            )
+            if use_gn:
+                bbox_tower.append(nn.GroupNorm(32, in_channels))
+            if use_bn:
+                bbox_tower.append(nn.BatchNorm2d(in_channels))
+            bbox_tower.append(nn.ReLU())
+
+        self.add_module('cls_tower', nn.Sequential(*cls_tower))
+        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
+        self.cls_logits = nn.Conv2d(
+            in_channels, num_classes, kernel_size=3, stride=1,
+            padding=1
+        )
+        self.bbox_pred = nn.Conv2d(
+            in_channels, 4, kernel_size=3, stride=1,
+            padding=1
+        )
+        self.centerness = nn.Conv2d(
+            in_channels, 1, kernel_size=3, stride=1,
+            padding=1
+        )
+
+        # initialization
+        for modules in [self.cls_tower, self.bbox_tower,
+                        self.cls_logits, self.bbox_pred,
+                        self.centerness]:
+            for l in modules.modules():
+                if isinstance(l, nn.Conv2d):
+                    torch.nn.init.normal_(l.weight, std=0.01)
+                    torch.nn.init.constant_(l.bias, 0)
+
+        # initialize the bias for focal loss
+        prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        torch.nn.init.constant_(self.cls_logits.bias, bias_value)
+
+        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
+
+    def forward(self, x):
+        logits = []
+        bbox_reg = []
+        centerness = []
+        for l, feature in enumerate(x):
+            cls_tower = self.cls_tower(feature)
+            box_tower = self.bbox_tower(feature)
+
+            logits.append(self.cls_logits(cls_tower))
+            if self.centerness_on_reg:
+                centerness.append(self.centerness(box_tower))
+            else:
+                centerness.append(self.centerness(cls_tower))
+
+            bbox_pred = self.scales[l](self.bbox_pred(box_tower))
+            if self.norm_reg_targets:
+                bbox_pred = F.relu(bbox_pred)
+                if self.training:
+                    bbox_reg.append(bbox_pred)
+                else:
+                    bbox_reg.append(bbox_pred * self.fpn_strides[l])
+            else:
+                bbox_reg.append(torch.exp(bbox_pred))
+        return logits, bbox_reg, centerness
+
+
+class FCOSModule(torch.nn.Module):
+    """
+    Module for FCOS computation. Takes feature maps from the backbone and
+    FCOS outputs and losses. Only Test on FPN now.
+    """
+
+    def __init__(self, cfg):
+        super(FCOSModule, self).__init__()
+
+        head = FCOSHead(cfg)
+
+        box_selector_train = make_fcos_postprocessor(cfg, is_train=True)
+        box_selector_test = make_fcos_postprocessor(cfg, is_train=False)
+
+        loss_evaluator = make_fcos_loss_evaluator(cfg)
+
+        self.cfg = cfg
+        self.head = head
+        self.box_selector_train = box_selector_train
+        self.box_selector_test = box_selector_test
+        self.loss_evaluator = loss_evaluator
+        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
+        if not cfg.MODEL.RPN_ONLY:
+            self.anchor_generator = make_center_anchor_generator(cfg)
+
+
+    def forward(self, images, features, targets=None):
+        """
+        Arguments:
+            images (ImageList): images for which we want to compute the predictions
+            features (list[Tensor]): features computed from the images that are
+                used for computing the predictions. Each tensor in the list
+                correspond to different feature levels
+            targets (list[BoxList): ground-truth boxes present in the image (optional)
+
+        Returns:
+            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
+                image.
+            losses (dict[Tensor]): the losses for the model during training. During
+                testing, it is an empty dict.
+        """
+        box_cls, box_regression, centerness = self.head(features)
+        locations = self.compute_locations(features)
+        if self.training and targets is not None:
+            return self._forward_train(
+                locations, box_cls, box_regression,
+                centerness, targets, images.image_sizes
+            )
+        else:
+            return self._forward_test(
+                locations, box_cls, box_regression,
+                centerness, images.image_sizes
+            )
+
+    def _forward_train(self, locations, box_cls, box_regression, centerness, targets, image_sizes=None):
+        loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator(
+            locations, box_cls, box_regression, centerness, targets
+        )
+        losses = {
+            "loss_cls": loss_box_cls,
+            "loss_reg": loss_box_reg,
+            "loss_centerness": loss_centerness
+        }
+        if self.cfg.MODEL.RPN_ONLY:
+            return None, losses
+        else:
+            boxes = self.box_selector_train(
+                locations, box_cls, box_regression,
+                centerness, image_sizes
+            )
+            proposals = self.anchor_generator(boxes, image_sizes, centerness)
+            return proposals, losses
+
+    def _forward_test(self, locations, box_cls, box_regression, centerness, image_sizes):
+        boxes = self.box_selector_test(
+            locations, box_cls, box_regression,
+            centerness, image_sizes
+        )
+        if not self.cfg.MODEL.RPN_ONLY:
+            boxes = self.anchor_generator(boxes, image_sizes, centerness)
+        return boxes, {}
+
+    def compute_locations(self, features):
+        locations = []
+        for level, feature in enumerate(features):
+            h, w = feature.size()[-2:]
+            locations_per_level = self.compute_locations_per_level(
+                h, w, self.fpn_strides[level],
+                feature.device
+            )
+            locations.append(locations_per_level)
+        return locations
+
+    def compute_locations_per_level(self, h, w, stride, device):
+        shifts_x = torch.arange(
+            0, w * stride, step=stride,
+            dtype=torch.float32, device=device
+        )
+        shifts_y = torch.arange(
+            0, h * stride, step=stride,
+            dtype=torch.float32, device=device
+        )
+        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+        shift_x = shift_x.reshape(-1)
+        shift_y = shift_y.reshape(-1)
+        locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
+        return locations
+
+
+
+
diff --git a/maskrcnn_benchmark/modeling/rpn/inference.py b/maskrcnn_benchmark/modeling/rpn/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4b1430e8add3ca92dace5ebf5bc8ba4261729c
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/inference.py
@@ -0,0 +1,850 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+
+import torch
+
+from maskrcnn_benchmark.modeling.box_coder import BoxCoder
+from maskrcnn_benchmark.structures.bounding_box import BoxList, _onnx_clip_boxes_to_image
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_ml_nms
+from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
+
+from ..utils import permute_and_flatten
+import pdb
+
+class RPNPostProcessor(torch.nn.Module):
+    """
+    Performs post-processing on the outputs of the RPN boxes, before feeding the
+    proposals to the heads
+    """
+
+    def __init__(
+            self,
+            pre_nms_top_n,
+            post_nms_top_n,
+            nms_thresh,
+            min_size,
+            box_coder=None,
+            fpn_post_nms_top_n=None,
+            onnx=False
+    ):
+        """
+        Arguments:
+            pre_nms_top_n (int)
+            post_nms_top_n (int)
+            nms_thresh (float)
+            min_size (int)
+            box_coder (BoxCoder)
+            fpn_post_nms_top_n (int)
+        """
+        super(RPNPostProcessor, self).__init__()
+        self.pre_nms_top_n = pre_nms_top_n
+        self.post_nms_top_n = post_nms_top_n
+        self.nms_thresh = nms_thresh
+        self.min_size = min_size
+        self.onnx = onnx
+
+        if box_coder is None:
+            box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+        self.box_coder = box_coder
+
+        if fpn_post_nms_top_n is None:
+            fpn_post_nms_top_n = post_nms_top_n
+        self.fpn_post_nms_top_n = fpn_post_nms_top_n
+
+    def add_gt_proposals(self, proposals, targets):
+        """
+        Arguments:
+            proposals: list[BoxList]
+            targets: list[BoxList]
+        """
+        # Get the device we're operating on
+        device = proposals[0].bbox.device
+
+        gt_boxes = [target.copy_with_fields([]) for target in targets]
+
+        # later cat of bbox requires all fields to be present for all bbox
+        # so we need to add a dummy for objectness that's missing
+        for gt_box in gt_boxes:
+            gt_box.add_field("objectness", torch.ones(len(gt_box), device=device))
+
+        proposals = [
+            cat_boxlist((proposal, gt_box))
+            for proposal, gt_box in zip(proposals, gt_boxes)
+        ]
+
+        return proposals
+
+    def forward_for_single_feature_map(self, anchors, objectness, box_regression):
+        """
+        Arguments:
+            anchors: list[BoxList]
+            objectness: tensor of size N, A, H, W
+            box_regression: tensor of size N, A * 4, H, W
+        """
+        device = objectness.device
+        N, A, H, W = objectness.shape
+
+        # put in the same format as anchors
+        objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1)
+        objectness = objectness.sigmoid()
+        box_regression = box_regression.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2)
+        box_regression = box_regression.reshape(N, -1, 4)
+
+        num_anchors = A * H * W
+
+        pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
+        objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True)
+
+        batch_idx = torch.arange(N, device=device)[:, None]
+        box_regression = box_regression[batch_idx, topk_idx]
+
+        image_shapes = [box.size for box in anchors]
+        concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
+        concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx]
+
+        proposals = self.box_coder.decode(
+            box_regression.view(-1, 4), concat_anchors.view(-1, 4)
+        )
+
+        proposals = proposals.view(N, -1, 4)
+
+        result = []
+        for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
+            if self.onnx:
+                proposal = _onnx_clip_boxes_to_image(proposal, im_shape)
+                boxlist = BoxList(proposal, im_shape, mode="xyxy")
+            else:
+                boxlist = BoxList(proposal, im_shape, mode="xyxy")
+                boxlist = boxlist.clip_to_image(remove_empty=False)
+
+            boxlist.add_field("objectness", score)
+            boxlist = remove_small_boxes(boxlist, self.min_size)
+            boxlist = boxlist_nms(
+                boxlist,
+                self.nms_thresh,
+                max_proposals=self.post_nms_top_n,
+                score_field="objectness",
+            )
+            result.append(boxlist)
+        return result
+
+    def forward(self, anchors, objectness, box_regression, targets=None):
+        """
+        Arguments:
+            anchors: list[list[BoxList]]
+            objectness: list[tensor]
+            box_regression: list[tensor]
+
+        Returns:
+            boxlists (list[BoxList]): the post-processed anchors, after
+                applying box decoding and NMS
+        """
+        sampled_boxes = []
+        num_levels = len(objectness)
+        anchors = list(zip(*anchors))
+        for a, o, b in zip(anchors, objectness, box_regression):
+            sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))
+
+        boxlists = list(zip(*sampled_boxes))
+        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
+
+        if num_levels > 1:
+            boxlists = self.select_over_all_levels(boxlists)
+
+        # append ground-truth bboxes to proposals
+        if self.training and targets is not None:
+            boxlists = self.add_gt_proposals(boxlists, targets)
+
+        return boxlists
+
+    def select_over_all_levels(self, boxlists):
+        num_images = len(boxlists)
+        # different behavior during training and during testing:
+        # during training, post_nms_top_n is over *all* the proposals combined, while
+        # during testing, it is over the proposals for each image
+        # TODO resolve this difference and make it consistent. It should be per image,
+        # and not per batch
+        if self.training:
+            objectness = torch.cat(
+                [boxlist.get_field("objectness") for boxlist in boxlists], dim=0
+            )
+            box_sizes = [len(boxlist) for boxlist in boxlists]
+            post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
+            _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
+            inds_mask = torch.zeros_like(objectness, dtype=torch.bool)
+            inds_mask[inds_sorted] = 1
+            inds_mask = inds_mask.split(box_sizes)
+            for i in range(num_images):
+                boxlists[i] = boxlists[i][inds_mask[i]]
+        else:
+            for i in range(num_images):
+                objectness = boxlists[i].get_field("objectness")
+                post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
+                _, inds_sorted = torch.topk(
+                    objectness, post_nms_top_n, dim=0, sorted=True
+                )
+                boxlists[i] = boxlists[i][inds_sorted]
+        return boxlists
+
+
+def make_rpn_postprocessor(config, rpn_box_coder, is_train):
+    fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN
+    if not is_train:
+        fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST
+
+    pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TRAIN
+    post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TRAIN
+    if not is_train:
+        pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TEST
+        post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TEST
+    nms_thresh = config.MODEL.RPN.NMS_THRESH
+    min_size = config.MODEL.RPN.MIN_SIZE
+    onnx = config.MODEL.ONNX
+    box_selector = RPNPostProcessor(
+        pre_nms_top_n=pre_nms_top_n,
+        post_nms_top_n=post_nms_top_n,
+        nms_thresh=nms_thresh,
+        min_size=min_size,
+        box_coder=rpn_box_coder,
+        fpn_post_nms_top_n=fpn_post_nms_top_n,
+        onnx=onnx
+    )
+    return box_selector
+
+
+class RetinaPostProcessor(torch.nn.Module):
+    """
+    Performs post-processing on the outputs of the RetinaNet boxes.
+    This is only used in the testing.
+    """
+
+    def __init__(
+            self,
+            pre_nms_thresh,
+            pre_nms_top_n,
+            nms_thresh,
+            fpn_post_nms_top_n,
+            min_size,
+            num_classes,
+            box_coder=None,
+    ):
+        """
+        Arguments:
+            pre_nms_thresh (float)
+            pre_nms_top_n (int)
+            nms_thresh (float)
+            fpn_post_nms_top_n (int)
+            min_size (int)
+            num_classes (int)
+            box_coder (BoxCoder)
+        """
+        super(RetinaPostProcessor, self).__init__()
+        self.pre_nms_thresh = pre_nms_thresh
+        self.pre_nms_top_n = pre_nms_top_n
+        self.nms_thresh = nms_thresh
+        self.fpn_post_nms_top_n = fpn_post_nms_top_n
+        self.min_size = min_size
+        self.num_classes = num_classes
+
+        if box_coder is None:
+            box_coder = BoxCoder(weights=(10., 10., 5., 5.))
+        self.box_coder = box_coder
+
+    def forward_for_single_feature_map(self, anchors, box_cls, box_regression):
+        """
+        Arguments:
+            anchors: list[BoxList]
+            box_cls: tensor of size N, A * C, H, W
+            box_regression: tensor of size N, A * 4, H, W
+        """
+        device = box_cls.device
+        N, _, H, W = box_cls.shape
+        A = box_regression.size(1) // 4
+        C = box_cls.size(1) // A
+
+        # put in the same format as anchors
+        box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
+        box_cls = box_cls.sigmoid()
+
+        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
+        box_regression = box_regression.reshape(N, -1, 4)
+
+        num_anchors = A * H * W
+
+        candidate_inds = box_cls > self.pre_nms_thresh
+
+        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
+        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)
+
+        results = []
+        for per_box_cls, per_box_regression, per_pre_nms_top_n, \
+            per_candidate_inds, per_anchors in zip(
+            box_cls,
+            box_regression,
+            pre_nms_top_n,
+            candidate_inds,
+            anchors):
+            # Sort and select TopN
+            # TODO most of this can be made out of the loop for
+            # all images.
+            # TODO:Yang: Not easy to do. Because the numbers of detections are
+            # different in each image. Therefore, this part needs to be done
+            # per image.
+            per_box_cls = per_box_cls[per_candidate_inds]
+
+            per_box_cls, top_k_indices = \
+                per_box_cls.topk(per_pre_nms_top_n, sorted=False)
+
+            per_candidate_nonzeros = \
+                per_candidate_inds.nonzero()[top_k_indices, :]
+
+            per_box_loc = per_candidate_nonzeros[:, 0]
+            per_class = per_candidate_nonzeros[:, 1]
+            per_class += 1
+
+            detections = self.box_coder.decode(
+                per_box_regression[per_box_loc, :].view(-1, 4),
+                per_anchors.bbox[per_box_loc, :].view(-1, 4)
+            )
+
+            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
+            boxlist.add_field("labels", per_class)
+            boxlist.add_field("scores", per_box_cls)
+            boxlist = boxlist.clip_to_image(remove_empty=False)
+            boxlist = remove_small_boxes(boxlist, self.min_size)
+            results.append(boxlist)
+
+        return results
+
+    # TODO very similar to filter_results from PostProcessor
+    # but filter_results is per image
+    # TODO Yang: solve this issue in the future. No good solution
+    # right now.
+    def select_over_all_levels(self, boxlists):
+        num_images = len(boxlists)
+        results = []
+        for i in range(num_images):
+            scores = boxlists[i].get_field("scores")
+            labels = boxlists[i].get_field("labels")
+            boxes = boxlists[i].bbox
+            boxlist = boxlists[i]
+            result = []
+            # skip the background
+            for j in range(1, self.num_classes):
+                inds = (labels == j).nonzero().view(-1)
+
+                scores_j = scores[inds]
+                boxes_j = boxes[inds, :].view(-1, 4)
+                boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
+                boxlist_for_class.add_field("scores", scores_j)
+                boxlist_for_class = boxlist_nms(
+                    boxlist_for_class, self.nms_thresh,
+                    score_field="scores"
+                )
+                num_labels = len(boxlist_for_class)
+                boxlist_for_class.add_field(
+                    "labels", torch.full((num_labels,), j,
+                                         dtype=torch.int64,
+                                         device=scores.device)
+                )
+                result.append(boxlist_for_class)
+
+            result = cat_boxlist(result)
+            number_of_detections = len(result)
+
+            # Limit to max_per_image detections **over all classes**
+            if number_of_detections > self.fpn_post_nms_top_n > 0:
+                cls_scores = result.get_field("scores")
+                image_thresh, _ = torch.kthvalue(
+                    cls_scores.cpu(),
+                    number_of_detections - self.fpn_post_nms_top_n + 1
+                )
+                keep = cls_scores >= image_thresh.item()
+                keep = torch.nonzero(keep).squeeze(1)
+                result = result[keep]
+            results.append(result)
+        return results
+
+    def forward(self, anchors, objectness, box_regression, targets=None):
+        """
+        Arguments:
+            anchors: list[list[BoxList]]
+            objectness: list[tensor]
+            box_regression: list[tensor]
+
+        Returns:
+            boxlists (list[BoxList]): the post-processed anchors, after
+                applying box decoding and NMS
+        """
+        sampled_boxes = []
+        anchors = list(zip(*anchors))
+        for a, o, b in zip(anchors, objectness, box_regression):
+            sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))
+
+        boxlists = list(zip(*sampled_boxes))
+        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
+
+        boxlists = self.select_over_all_levels(boxlists)
+
+        return boxlists
+
+
+def make_retina_postprocessor(config, rpn_box_coder, is_train):
+    pre_nms_thresh = config.MODEL.RETINANET.INFERENCE_TH
+    pre_nms_top_n = config.MODEL.RETINANET.PRE_NMS_TOP_N
+    nms_thresh = config.MODEL.RETINANET.NMS_TH
+    fpn_post_nms_top_n = config.MODEL.RETINANET.DETECTIONS_PER_IMG
+    min_size = 0
+
+    box_selector = RetinaPostProcessor(
+        pre_nms_thresh=pre_nms_thresh,
+        pre_nms_top_n=pre_nms_top_n,
+        nms_thresh=nms_thresh,
+        fpn_post_nms_top_n=fpn_post_nms_top_n,
+        min_size=min_size,
+        num_classes=config.MODEL.RETINANET.NUM_CLASSES,
+        box_coder=rpn_box_coder,
+    )
+
+    return box_selector
+
+
+class FCOSPostProcessor(torch.nn.Module):
+    """
+    Performs post-processing on the outputs of the RetinaNet boxes.
+    This is only used in the testing.
+    """
+
+    def __init__(
+            self,
+            pre_nms_thresh,
+            pre_nms_top_n,
+            nms_thresh,
+            fpn_post_nms_top_n,
+            min_size,
+            num_classes,
+            bbox_aug_enabled=False
+    ):
+        """
+        Arguments:
+            pre_nms_thresh (float)
+            pre_nms_top_n (int)
+            nms_thresh (float)
+            fpn_post_nms_top_n (int)
+            min_size (int)
+            num_classes (int)
+            box_coder (BoxCoder)
+        """
+        super(FCOSPostProcessor, self).__init__()
+        self.pre_nms_thresh = pre_nms_thresh
+        self.pre_nms_top_n = pre_nms_top_n
+        self.nms_thresh = nms_thresh
+        self.fpn_post_nms_top_n = fpn_post_nms_top_n
+        self.min_size = min_size
+        self.num_classes = num_classes
+        self.bbox_aug_enabled = bbox_aug_enabled
+
+    def forward_for_single_feature_map(
+            self, locations, box_cls,
+            box_regression, centerness,
+            image_sizes):
+        """
+        Arguments:
+            anchors: list[BoxList]
+            box_cls: tensor of size N, A * C, H, W
+            box_regression: tensor of size N, A * 4, H, W
+        """
+        N, C, H, W = box_cls.shape
+
+        # put in the same format as locations
+        box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
+        box_cls = box_cls.reshape(N, -1, C).sigmoid()
+        box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1)
+        box_regression = box_regression.reshape(N, -1, 4)
+        centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1)
+        centerness = centerness.reshape(N, -1).sigmoid()
+
+        candidate_inds = box_cls > self.pre_nms_thresh
+        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
+        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)
+
+        # multiply the classification scores with centerness scores
+        box_cls = box_cls * centerness[:, :, None]
+
+        results = []
+        for i in range(N):
+            per_box_cls = box_cls[i]
+            per_candidate_inds = candidate_inds[i]
+            per_box_cls = per_box_cls[per_candidate_inds]
+
+            per_candidate_nonzeros = per_candidate_inds.nonzero()
+            per_box_loc = per_candidate_nonzeros[:, 0]
+            per_class = per_candidate_nonzeros[:, 1] + 1
+
+            per_box_regression = box_regression[i]
+            per_box_regression = per_box_regression[per_box_loc]
+            per_locations = locations[per_box_loc]
+
+            per_pre_nms_top_n = pre_nms_top_n[i]
+
+            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
+                per_box_cls, top_k_indices = \
+                    per_box_cls.topk(per_pre_nms_top_n, sorted=False)
+                per_class = per_class[top_k_indices]
+                per_box_regression = per_box_regression[top_k_indices]
+                per_locations = per_locations[top_k_indices]
+
+            detections = torch.stack([
+                per_locations[:, 0] - per_box_regression[:, 0],
+                per_locations[:, 1] - per_box_regression[:, 1],
+                per_locations[:, 0] + per_box_regression[:, 2],
+                per_locations[:, 1] + per_box_regression[:, 3],
+            ], dim=1)
+
+            h, w = image_sizes[i]
+            boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy")
+            boxlist.add_field('centers', per_locations)
+            boxlist.add_field("labels", per_class)
+            boxlist.add_field("scores", torch.sqrt(per_box_cls))
+            boxlist = boxlist.clip_to_image(remove_empty=False)
+            boxlist = remove_small_boxes(boxlist, self.min_size)
+            results.append(boxlist)
+
+        return results
+
+    def forward(self, locations, box_cls, box_regression, centerness, image_sizes):
+        """
+        Arguments:
+            anchors: list[list[BoxList]]
+            box_cls: list[tensor]
+            box_regression: list[tensor]
+            image_sizes: list[(h, w)]
+        Returns:
+            boxlists (list[BoxList]): the post-processed anchors, after
+                applying box decoding and NMS
+        """
+        sampled_boxes = []
+        for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)):
+            sampled_boxes.append(
+                self.forward_for_single_feature_map(
+                    l, o, b, c, image_sizes
+                )
+            )
+
+        boxlists = list(zip(*sampled_boxes))
+        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
+        if not self.bbox_aug_enabled:
+            boxlists = self.select_over_all_levels(boxlists)
+
+        return boxlists
+
+    # TODO very similar to filter_results from PostProcessor
+    # but filter_results is per image
+    # TODO Yang: solve this issue in the future. No good solution
+    # right now.
+    def select_over_all_levels(self, boxlists):
+        num_images = len(boxlists)
+        results = []
+        for i in range(num_images):
+            # multiclass nms
+            result = boxlist_ml_nms(boxlists[i], self.nms_thresh)
+            number_of_detections = len(result)
+
+            # Limit to max_per_image detections **over all classes**
+            if number_of_detections > self.fpn_post_nms_top_n > 0:
+                cls_scores = result.get_field("scores")
+                image_thresh, _ = torch.kthvalue(
+                    cls_scores.cpu(),
+                    number_of_detections - self.fpn_post_nms_top_n + 1
+                )
+                keep = cls_scores >= image_thresh.item()
+                keep = torch.nonzero(keep).squeeze(1)
+                result = result[keep]
+            results.append(result)
+        return results
+
+
+def make_fcos_postprocessor(config, is_train=False):
+    pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH
+    if is_train:
+        pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH_TRAIN
+    pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N
+    fpn_post_nms_top_n = config.MODEL.FCOS.DETECTIONS_PER_IMG
+    if is_train:
+        pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N_TRAIN
+        fpn_post_nms_top_n = config.MODEL.FCOS.POST_NMS_TOP_N_TRAIN
+    nms_thresh = config.MODEL.FCOS.NMS_TH
+
+    box_selector = FCOSPostProcessor(
+        pre_nms_thresh=pre_nms_thresh,
+        pre_nms_top_n=pre_nms_top_n,
+        nms_thresh=nms_thresh,
+        fpn_post_nms_top_n=fpn_post_nms_top_n,
+        min_size=0,
+        num_classes=config.MODEL.FCOS.NUM_CLASSES,
+    )
+
+    return box_selector
+
+
+class ATSSPostProcessor(torch.nn.Module):
+    def __init__(
+            self,
+            pre_nms_thresh,
+            pre_nms_top_n,
+            nms_thresh,
+            fpn_post_nms_top_n,
+            min_size,
+            num_classes,
+            box_coder,
+            bbox_aug_enabled=False,
+            bbox_aug_vote=False,
+            score_agg='MEAN',
+            mdetr_style_aggregate_class_num=-1
+    ):
+        super(ATSSPostProcessor, self).__init__()
+        self.pre_nms_thresh = pre_nms_thresh
+        self.pre_nms_top_n = pre_nms_top_n
+        self.nms_thresh = nms_thresh
+        self.fpn_post_nms_top_n = fpn_post_nms_top_n
+        self.min_size = min_size
+        self.num_classes = num_classes
+        self.bbox_aug_enabled = bbox_aug_enabled
+        self.box_coder = box_coder
+        self.bbox_aug_vote = bbox_aug_vote
+        self.score_agg = score_agg
+        self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num
+
+    def forward_for_single_feature_map(self, box_regression, centerness, anchors,
+                                       box_cls=None,
+                                       token_logits=None,
+                                       dot_product_logits=None,
+                                       positive_map=None,
+                                       ):
+
+        N, _, H, W = box_regression.shape
+
+        A = box_regression.size(1) // 4
+
+        if box_cls is not None:
+            C = box_cls.size(1) // A
+
+        if token_logits is not None:
+            T = token_logits.size(1) // A
+
+        # put in the same format as anchors
+        if box_cls is not None:
+            #print('Classification.')
+            box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
+            box_cls = box_cls.sigmoid()
+
+        # binary focal loss version
+        if token_logits is not None:
+            #print('Token.')
+            token_logits = permute_and_flatten(token_logits, N, A, T, H, W)
+            token_logits = token_logits.sigmoid()
+            # turn back to original classes
+            scores = convert_grounding_to_od_logits(logits=token_logits, box_cls=box_cls, positive_map=positive_map,
+                                                    score_agg=self.score_agg)
+            box_cls = scores
+
+        # binary dot product focal version
+        if dot_product_logits is not None:
+            #print('Dot Product.')
+            dot_product_logits = dot_product_logits.sigmoid()
+            if self.mdetr_style_aggregate_class_num != -1:
+                scores = convert_grounding_to_od_logits_v2(
+                    logits=dot_product_logits,
+                    num_class=self.mdetr_style_aggregate_class_num,
+                    positive_map=positive_map,
+                    score_agg=self.score_agg,
+                    disable_minus_one=False)
+            else:
+                scores = convert_grounding_to_od_logits(logits=dot_product_logits, box_cls=box_cls,
+                                                        positive_map=positive_map,
+                                                        score_agg=self.score_agg)
+            box_cls = scores
+
+        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
+        box_regression = box_regression.reshape(N, -1, 4)
+
+        candidate_inds = box_cls > self.pre_nms_thresh
+        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
+        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)
+
+        centerness = permute_and_flatten(centerness, N, A, 1, H, W)
+        centerness = centerness.reshape(N, -1).sigmoid()
+
+        # multiply the classification scores with centerness scores
+
+        box_cls = box_cls * centerness[:, :, None]
+
+        results = []
+
+        for per_box_cls, per_box_regression, per_pre_nms_top_n, per_candidate_inds, per_anchors \
+                in zip(box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors):
+            per_box_cls = per_box_cls[per_candidate_inds]
+
+            per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False)
+
+            per_candidate_nonzeros = per_candidate_inds.nonzero()[top_k_indices, :]
+
+            per_box_loc = per_candidate_nonzeros[:, 0]
+            per_class = per_candidate_nonzeros[:, 1] + 1
+
+            # print(per_class)
+
+            detections = self.box_coder.decode(
+                per_box_regression[per_box_loc, :].view(-1, 4),
+                per_anchors.bbox[per_box_loc, :].view(-1, 4)
+            )
+
+            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
+            boxlist.add_field("labels", per_class)
+            boxlist.add_field("scores", torch.sqrt(per_box_cls))
+            boxlist = boxlist.clip_to_image(remove_empty=False)
+            boxlist = remove_small_boxes(boxlist, self.min_size)
+            results.append(boxlist)
+
+        return results
+
+    def forward(self, box_regression, centerness, anchors,
+                box_cls=None,
+                token_logits=None,
+                dot_product_logits=None,
+                positive_map=None,
+                ):
+        sampled_boxes = []
+        anchors = list(zip(*anchors))
+        for idx, (b, c, a) in enumerate(zip(box_regression, centerness, anchors)):
+            o = None
+            t = None
+            d = None
+            if box_cls is not None:
+                o = box_cls[idx]
+            if token_logits is not None:
+                t = token_logits[idx]
+            if dot_product_logits is not None:
+                d = dot_product_logits[idx]
+
+            sampled_boxes.append(
+                self.forward_for_single_feature_map(b, c, a, o, t, d, positive_map)
+            )
+
+        boxlists = list(zip(*sampled_boxes))
+        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
+        if not (self.bbox_aug_enabled and not self.bbox_aug_vote):
+            boxlists = self.select_over_all_levels(boxlists)
+
+        return boxlists
+
+    # TODO very similar to filter_results from PostProcessor
+    # but filter_results is per image
+    # TODO Yang: solve this issue in the future. No good solution
+    # right now.
+    def select_over_all_levels(self, boxlists):
+        num_images = len(boxlists)
+        results = []
+        for i in range(num_images):
+            # multiclass nms
+            result = boxlist_ml_nms(boxlists[i], self.nms_thresh)
+            number_of_detections = len(result)
+
+            # Limit to max_per_image detections **over all classes**
+            if number_of_detections > self.fpn_post_nms_top_n > 0:
+                cls_scores = result.get_field("scores")
+                image_thresh, _ = torch.kthvalue(
+                    # TODO: confirm with Pengchuan and Xiyang, torch.kthvalue is not implemented for 'Half'
+                    # cls_scores.cpu(),
+                    cls_scores.cpu().float(),
+                    number_of_detections - self.fpn_post_nms_top_n + 1
+                )
+                keep = cls_scores >= image_thresh.item()
+                keep = torch.nonzero(keep).squeeze(1)
+                result = result[keep]
+            results.append(result)
+        return results
+
+
+def convert_grounding_to_od_logits(logits, box_cls, positive_map, score_agg=None):
+    scores = torch.zeros(logits.shape[0], logits.shape[1], box_cls.shape[2]).to(logits.device)
+    # 256 -> 80, average for each class
+    if positive_map is not None:
+        # score aggregation method
+        if score_agg == "MEAN":
+            for label_j in positive_map:
+                scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].mean(-1)
+        elif score_agg == "MAX":
+            # torch.max() returns (values, indices)
+            for label_j in positive_map:
+                scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].max(-1)[
+                    0]
+        elif score_agg == "ONEHOT":
+            # one hot
+            scores = logits[:, :, :len(positive_map)]
+        else:
+            raise NotImplementedError
+    return scores
+
+
+def convert_grounding_to_od_logits_v2(logits, num_class, positive_map, score_agg=None, disable_minus_one = True):
+    
+    scores = torch.zeros(logits.shape[0], logits.shape[1], num_class).to(logits.device)
+    # 256 -> 80, average for each class
+    if positive_map is not None:
+        # score aggregation method
+        if score_agg == "MEAN":
+            for label_j in positive_map:
+                locations_label_j = positive_map[label_j]
+                if isinstance(locations_label_j, int):
+                    locations_label_j = [locations_label_j]
+                scores[:, :, label_j if disable_minus_one else label_j - 1] = logits[:, :, torch.LongTensor(locations_label_j)].mean(-1)
+        elif score_agg == "POWER":
+            for label_j in positive_map:
+                locations_label_j = positive_map[label_j]
+                if isinstance(locations_label_j, int):
+                    locations_label_j = [locations_label_j]
+
+                probability = torch.prod(logits[:, :, torch.LongTensor(locations_label_j)], dim=-1).squeeze(-1)
+                probability = torch.pow(probability, 1/len(locations_label_j))
+                scores[:, :, label_j if disable_minus_one else label_j - 1] = probability
+        elif score_agg == "MAX":
+            # torch.max() returns (values, indices)
+            for label_j in positive_map:
+                scores[:, :, label_j if disable_minus_one else label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].max(-1)[
+                    0]
+        elif score_agg == "ONEHOT":
+            # one hot
+            scores = logits[:, :, :len(positive_map)]
+        else:
+            raise NotImplementedError
+    return scores
+
+def make_atss_postprocessor(config, box_coder, is_train=False):
+    pre_nms_thresh = config.MODEL.ATSS.INFERENCE_TH
+    if is_train:
+        pre_nms_thresh = config.MODEL.ATSS.INFERENCE_TH_TRAIN
+    pre_nms_top_n = config.MODEL.ATSS.PRE_NMS_TOP_N
+    fpn_post_nms_top_n = config.MODEL.ATSS.DETECTIONS_PER_IMG
+    if is_train:
+        pre_nms_top_n = config.MODEL.ATSS.PRE_NMS_TOP_N_TRAIN
+        fpn_post_nms_top_n = config.MODEL.ATSS.POST_NMS_TOP_N_TRAIN
+    nms_thresh = config.MODEL.ATSS.NMS_TH
+    score_agg = config.MODEL.DYHEAD.SCORE_AGG
+
+    box_selector = ATSSPostProcessor(
+        pre_nms_thresh=pre_nms_thresh,
+        pre_nms_top_n=pre_nms_top_n,
+        nms_thresh=nms_thresh,
+        fpn_post_nms_top_n=fpn_post_nms_top_n,
+        min_size=0,
+        num_classes=config.MODEL.ATSS.NUM_CLASSES,
+        box_coder=box_coder,
+        bbox_aug_enabled=config.TEST.USE_MULTISCALE,
+        score_agg=score_agg,
+        mdetr_style_aggregate_class_num=config.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM
+    )
+
+    return box_selector
diff --git a/maskrcnn_benchmark/modeling/rpn/loss.py b/maskrcnn_benchmark/modeling/rpn/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..097fdc5ffd0a864110a26f6cbb0ee54b6af38f2b
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/loss.py
@@ -0,0 +1,1251 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+This file contains specific functions for computing losses on the RPN
+file
+"""
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
+from ..utils import cat, concat_box_prediction_layers
+
+from maskrcnn_benchmark.layers import smooth_l1_loss
+from maskrcnn_benchmark.modeling.matcher import Matcher
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.layers import SigmoidFocalLoss, IOULoss, TokenSigmoidFocalLoss
+from maskrcnn_benchmark.utils.comm import get_world_size, reduce_sum
+from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
+from maskrcnn_benchmark.utils.shallow_contrastive_loss_helper import *
+
+from transformers import AutoTokenizer
+
+INF = 1e8
+
+
+class RPNLossComputation(object):
+    """
+    This class computes the RPN loss.
+    """
+
+    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder):
+        """
+        Arguments:
+            proposal_matcher (Matcher)
+            fg_bg_sampler (BalancedPositiveNegativeSampler)
+            box_coder (BoxCoder)
+        """
+        # self.target_preparator = target_preparator
+        self.proposal_matcher = proposal_matcher
+        self.fg_bg_sampler = fg_bg_sampler
+        self.box_coder = box_coder
+
+    def match_targets_to_anchors(self, anchor, target):
+        match_quality_matrix = boxlist_iou(target, anchor)
+        matched_idxs = self.proposal_matcher(match_quality_matrix)
+        # RPN doesn't need any fields from target
+        # for creating the labels, so clear them all
+        target = target.copy_with_fields([])
+        # get the targets corresponding GT for each anchor
+        # NB: need to clamp the indices because we can have a single
+        # GT in the image, and matched_idxs can be -2, which goes
+        # out of bounds
+
+        if len(target):
+            matched_targets = target[matched_idxs.clamp(min=0)]
+        else:
+            matched_targets = target
+
+        matched_targets.add_field("matched_idxs", matched_idxs)
+        return matched_targets
+
+    def prepare_targets(self, anchors, targets):
+        labels = []
+        regression_targets = []
+        for anchors_per_image, targets_per_image in zip(anchors, targets):
+            matched_targets = self.match_targets_to_anchors(
+                anchors_per_image, targets_per_image
+            )
+
+            matched_idxs = matched_targets.get_field("matched_idxs")
+            labels_per_image = matched_idxs >= 0
+            labels_per_image = labels_per_image.to(dtype=torch.float32)
+            # discard anchors that go out of the boundaries of the image
+            labels_per_image[~anchors_per_image.get_field("visibility")] = -1
+
+            # discard indices that are between thresholds
+            inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
+            labels_per_image[inds_to_discard] = -1
+
+            # compute regression targets
+            if not matched_targets.bbox.shape[0]:
+                zeros = torch.zeros_like(labels_per_image)
+                regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1)
+            else:
+                regression_targets_per_image = self.box_coder.encode(matched_targets.bbox, anchors_per_image.bbox)
+
+            labels.append(labels_per_image)
+            regression_targets.append(regression_targets_per_image)
+
+        return labels, regression_targets
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def __call__(self, anchors, objectness, box_regression, targets):
+        """
+        Arguments:
+            anchors (list[BoxList])
+            objectness (list[Tensor])
+            box_regression (list[Tensor])
+            targets (list[BoxList])
+
+        Returns:
+            objectness_loss (Tensor)
+            box_loss (Tensor
+        """
+        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
+        labels, regression_targets = self.prepare_targets(anchors, targets)
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
+        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
+
+        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
+
+        objectness_flattened = []
+        box_regression_flattened = []
+        # for each feature level, permute the outputs to make them be in the
+        # same format as the labels. Note that the labels are computed for
+        # all feature levels concatenated, so we keep the same representation
+        # for the objectness and the box_regression
+        for objectness_per_level, box_regression_per_level in zip(
+                objectness, box_regression
+        ):
+            N, A, H, W = objectness_per_level.shape
+            objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape(
+                N, -1
+            )
+            box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W)
+            box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2)
+            box_regression_per_level = box_regression_per_level.reshape(N, -1, 4)
+            objectness_flattened.append(objectness_per_level)
+            box_regression_flattened.append(box_regression_per_level)
+        # concatenate on the first dimension (representing the feature levels), to
+        # take into account the way the labels were generated (with all feature maps
+        # being concatenated as well)
+        objectness = cat(objectness_flattened, dim=1).reshape(-1)
+        box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)
+
+        labels = torch.cat(labels, dim=0)
+        regression_targets = torch.cat(regression_targets, dim=0)
+
+        box_loss = smooth_l1_loss(
+            box_regression[sampled_pos_inds],
+            regression_targets[sampled_pos_inds],
+            beta=1.0 / 9,
+            size_average=False,
+        ) / (sampled_inds.numel())
+
+        objectness_loss = F.binary_cross_entropy_with_logits(
+            objectness[sampled_inds], labels[sampled_inds]
+        )
+
+        return objectness_loss, box_loss
+
+
+class FocalLossComputation(object):
+    """
+    This class computes the RetinaNet loss.
+    """
+
+    def __init__(self, proposal_matcher, box_coder,
+                 generate_labels_func,
+                 sigmoid_focal_loss,
+                 bbox_reg_beta=0.11,
+                 regress_norm=1.0):
+        """
+        Arguments:
+            proposal_matcher (Matcher)
+            box_coder (BoxCoder)
+        """
+        self.proposal_matcher = proposal_matcher
+        self.box_coder = box_coder
+        self.box_cls_loss_func = sigmoid_focal_loss
+        self.bbox_reg_beta = bbox_reg_beta
+        self.copied_fields = ['labels']
+        self.generate_labels_func = generate_labels_func
+        self.discard_cases = ['between_thresholds']
+        self.regress_norm = regress_norm
+
+    def match_targets_to_anchors(self, anchor, target, copied_fields=[]):
+        match_quality_matrix = boxlist_iou(target, anchor)
+        matched_idxs = self.proposal_matcher(match_quality_matrix)
+        # RPN doesn't need any fields from target
+        # for creating the labels, so clear them all
+        target = target.copy_with_fields(copied_fields)
+        # get the targets corresponding GT for each anchor
+        # NB: need to clamp the indices because we can have a single
+        # GT in the image, and matched_idxs can be -2, which goes
+        # out of bounds
+        matched_targets = target[matched_idxs.clamp(min=0)]
+        matched_targets.add_field("matched_idxs", matched_idxs)
+        return matched_targets
+
+    def prepare_targets(self, anchors, targets):
+        labels = []
+        regression_targets = []
+        for anchors_per_image, targets_per_image in zip(anchors, targets):
+            matched_targets = self.match_targets_to_anchors(
+                anchors_per_image, targets_per_image, self.copied_fields
+            )
+
+            matched_idxs = matched_targets.get_field("matched_idxs")
+            labels_per_image = self.generate_labels_func(matched_targets)
+            labels_per_image = labels_per_image.to(dtype=torch.float32)
+
+            # Background (negative examples)
+            bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
+            labels_per_image[bg_indices] = 0
+
+            # discard anchors that go out of the boundaries of the image
+            if "not_visibility" in self.discard_cases:
+                labels_per_image[~anchors_per_image.get_field("visibility")] = -1
+
+            # discard indices that are between thresholds
+            if "between_thresholds" in self.discard_cases:
+                inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
+                labels_per_image[inds_to_discard] = -1
+
+            # compute regression targets
+            regression_targets_per_image = self.box_coder.encode(
+                matched_targets.bbox, anchors_per_image.bbox
+            )
+
+            labels.append(labels_per_image)
+            regression_targets.append(regression_targets_per_image)
+
+        return labels, regression_targets
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def __call__(self, anchors, box_cls, box_regression, targets):
+        """
+        Arguments:
+            anchors (list[BoxList])
+            box_cls (list[Tensor])
+            box_regression (list[Tensor])
+            targets (list[BoxList])
+
+        Returns:
+            retinanet_cls_loss (Tensor)
+            retinanet_regression_loss (Tensor
+        """
+        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
+        labels, regression_targets = self.prepare_targets(anchors, targets)
+
+        N = len(labels)
+        box_cls, box_regression = \
+            concat_box_prediction_layers(box_cls, box_regression)
+
+        labels = torch.cat(labels, dim=0)
+        regression_targets = torch.cat(regression_targets, dim=0)
+        pos_inds = torch.nonzero(labels > 0).squeeze(1)
+
+        retinanet_regression_loss = smooth_l1_loss(
+            box_regression[pos_inds],
+            regression_targets[pos_inds],
+            beta=self.bbox_reg_beta,
+            size_average=False,
+        ) / (max(1, pos_inds.numel() * self.regress_norm))
+
+        labels = labels.int()
+
+        retinanet_cls_loss = self.box_cls_loss_func(
+            box_cls,
+            labels
+        ) / (pos_inds.numel() + N)
+
+        return retinanet_cls_loss, retinanet_regression_loss
+
+
+class FCOSLossComputation(object):
+    """
+    This class computes the FCOS losses.
+    """
+
+    def __init__(self, cfg):
+        self.cls_loss_func = SigmoidFocalLoss(
+            cfg.MODEL.FOCAL.LOSS_GAMMA,
+            cfg.MODEL.FOCAL.LOSS_ALPHA
+        )
+        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
+        self.center_sampling_radius = cfg.MODEL.FCOS.CENTER_SAMPLING_RADIUS
+        self.iou_loss_type = cfg.MODEL.FCOS.IOU_LOSS_TYPE
+        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
+        self.use_gt_center = cfg.MODEL.FCOS.USE_GT_CENTER
+
+        # we make use of IOU Loss for bounding boxes regression,
+        # but we found that L1 in log scale can yield a similar performance
+        self.box_reg_loss_func = IOULoss(self.iou_loss_type)
+        self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum")
+
+    def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0):
+        '''
+        This code is from
+        https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/
+        maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42
+        '''
+        num_gts = gt.shape[0]
+        K = len(gt_xs)
+        gt = gt[None].expand(K, num_gts, 4)
+        center_x = (gt[..., 0] + gt[..., 2]) / 2
+        center_y = (gt[..., 1] + gt[..., 3]) / 2
+        center_gt = gt.new_zeros(gt.shape)
+        # no gt
+        if center_x[..., 0].sum() == 0:
+            return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8)
+        beg = 0
+        for level, n_p in enumerate(num_points_per):
+            end = beg + n_p
+            stride = strides[level] * radius
+            xmin = center_x[beg:end] - stride
+            ymin = center_y[beg:end] - stride
+            xmax = center_x[beg:end] + stride
+            ymax = center_y[beg:end] + stride
+            # limit sample region in gt
+            center_gt[beg:end, :, 0] = torch.where(
+                xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0]
+            )
+            center_gt[beg:end, :, 1] = torch.where(
+                ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1]
+            )
+            center_gt[beg:end, :, 2] = torch.where(
+                xmax > gt[beg:end, :, 2],
+                gt[beg:end, :, 2], xmax
+            )
+            center_gt[beg:end, :, 3] = torch.where(
+                ymax > gt[beg:end, :, 3],
+                gt[beg:end, :, 3], ymax
+            )
+            beg = end
+        left = gt_xs[:, None] - center_gt[..., 0]
+        right = center_gt[..., 2] - gt_xs[:, None]
+        top = gt_ys[:, None] - center_gt[..., 1]
+        bottom = center_gt[..., 3] - gt_ys[:, None]
+        center_bbox = torch.stack((left, top, right, bottom), -1)
+        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
+        return inside_gt_bbox_mask
+
+    def prepare_targets(self, points, targets):
+        object_sizes_of_interest = [
+            [-1, 64],
+            [64, 128],
+            [128, 256],
+            [256, 512],
+            [512, INF],
+        ]
+        expanded_object_sizes_of_interest = []
+        for l, points_per_level in enumerate(points):
+            object_sizes_of_interest_per_level = \
+                points_per_level.new_tensor(object_sizes_of_interest[l])
+            expanded_object_sizes_of_interest.append(
+                object_sizes_of_interest_per_level[None].expand(len(points_per_level), -1)
+            )
+
+        expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0)
+        num_points_per_level = [len(points_per_level) for points_per_level in points]
+        self.num_points_per_level = num_points_per_level
+        points_all_level = torch.cat(points, dim=0)
+        labels, reg_targets = self.compute_targets_for_locations(
+            points_all_level, targets, expanded_object_sizes_of_interest
+        )
+
+        for i in range(len(labels)):
+            labels[i] = torch.split(labels[i], num_points_per_level, dim=0)
+            reg_targets[i] = torch.split(reg_targets[i], num_points_per_level, dim=0)
+
+        labels_level_first = []
+        reg_targets_level_first = []
+        for level in range(len(points)):
+            labels_level_first.append(
+                torch.cat([labels_per_im[level] for labels_per_im in labels], dim=0)
+            )
+
+            reg_targets_per_level = torch.cat([
+                reg_targets_per_im[level]
+                for reg_targets_per_im in reg_targets
+            ], dim=0)
+
+            if self.norm_reg_targets:
+                reg_targets_per_level = reg_targets_per_level / self.fpn_strides[level]
+            reg_targets_level_first.append(reg_targets_per_level)
+
+        return labels_level_first, reg_targets_level_first
+
+    def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest):
+        labels = []
+        reg_targets = []
+        xs, ys = locations[:, 0], locations[:, 1]
+
+        for im_i in range(len(targets)):
+            targets_per_im = targets[im_i]
+            assert targets_per_im.mode == "xyxy"
+
+            if self.use_gt_center:
+                center = targets_per_im.get_field("cbox")
+                bboxes = center.bbox
+                area = center.area()
+            else:
+                bboxes = targets_per_im.bbox
+                area = targets_per_im.area()
+            labels_per_im = targets_per_im.get_field("labels")
+
+            l = xs[:, None] - bboxes[:, 0][None]
+            t = ys[:, None] - bboxes[:, 1][None]
+            r = bboxes[:, 2][None] - xs[:, None]
+            b = bboxes[:, 3][None] - ys[:, None]
+            reg_targets_per_im = torch.stack([l, t, r, b], dim=2)
+
+            if self.center_sampling_radius > 0:
+                is_in_boxes = self.get_sample_region(
+                    bboxes,
+                    self.fpn_strides,
+                    self.num_points_per_level,
+                    xs, ys,
+                    radius=self.center_sampling_radius
+                )
+            else:
+                # no center sampling, it will use all the locations within a ground-truth box
+                is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0
+
+            max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
+            # limit the regression range for each location
+            is_cared_in_the_level = \
+                (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \
+                (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]])
+
+            locations_to_gt_area = area[None].repeat(len(locations), 1)
+            locations_to_gt_area[is_in_boxes == 0] = INF
+            locations_to_gt_area[is_cared_in_the_level == 0] = INF
+
+            # if there are still more than one objects for a location,
+            # we choose the one with minimal area
+            locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(dim=1)
+
+            reg_targets_per_im = reg_targets_per_im[range(len(locations)), locations_to_gt_inds]
+            labels_per_im = labels_per_im[locations_to_gt_inds]
+            labels_per_im[locations_to_min_area == INF] = 0
+
+            labels.append(labels_per_im)
+            reg_targets.append(reg_targets_per_im)
+
+        return labels, reg_targets
+
+    def compute_centerness_targets(self, reg_targets):
+        left_right = reg_targets[:, [0, 2]]
+        top_bottom = reg_targets[:, [1, 3]]
+        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
+                     (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+        return torch.sqrt(centerness)
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def __call__(self, locations, box_cls, box_regression, centerness, targets):
+        """
+        Arguments:
+            locations (list[BoxList])
+            box_cls (list[Tensor])
+            box_regression (list[Tensor])
+            centerness (list[Tensor])
+            targets (list[BoxList])
+
+        Returns:
+            cls_loss (Tensor)
+            reg_loss (Tensor)
+            centerness_loss (Tensor)
+        """
+        N = box_cls[0].size(0)
+        num_classes = box_cls[0].size(1)
+        labels, reg_targets = self.prepare_targets(locations, targets)
+
+        box_cls_flatten = []
+        box_regression_flatten = []
+        centerness_flatten = []
+        labels_flatten = []
+        reg_targets_flatten = []
+        for l in range(len(labels)):
+            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes))
+            box_regression_flatten.append(box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4))
+            labels_flatten.append(labels[l].reshape(-1))
+            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
+            centerness_flatten.append(centerness[l].reshape(-1))
+
+        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
+        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
+        centerness_flatten = torch.cat(centerness_flatten, dim=0)
+        labels_flatten = torch.cat(labels_flatten, dim=0)
+        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)
+
+        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
+
+        box_regression_flatten = box_regression_flatten[pos_inds]
+        reg_targets_flatten = reg_targets_flatten[pos_inds]
+        centerness_flatten = centerness_flatten[pos_inds]
+
+        cls_loss = self.cls_loss_func(
+            box_cls_flatten,
+            labels_flatten.int()
+        ) / max(pos_inds.numel(), 1.0)
+
+        if pos_inds.numel() > 0:
+            centerness_targets = self.compute_centerness_targets(reg_targets_flatten)
+
+            reg_loss = self.box_reg_loss_func(
+                box_regression_flatten,
+                reg_targets_flatten,
+                centerness_targets
+            ) / centerness_targets.sum()
+            centerness_loss = self.centerness_loss_func(
+                centerness_flatten,
+                centerness_targets
+            ) / max(pos_inds.numel(), 1.0)
+        else:
+            reg_loss = box_regression_flatten.sum()
+            centerness_loss = centerness_flatten.sum()
+
+        return cls_loss, reg_loss, centerness_loss
+
+
+# class ATSSLossComputation(object):
+class ATSSLossComputation(torch.nn.Module):
+
+    def __init__(self, cfg, box_coder):
+        super(ATSSLossComputation, self).__init__()
+        
+        self.cfg = cfg
+        self.cls_loss_func = SigmoidFocalLoss(cfg.MODEL.FOCAL.LOSS_GAMMA, cfg.MODEL.FOCAL.LOSS_ALPHA)
+        self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum")
+        self.matcher = Matcher(cfg.MODEL.FOCAL.FG_IOU_THRESHOLD, cfg.MODEL.FOCAL.BG_IOU_THRESHOLD, True)
+        self.box_coder = box_coder
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            self.token_loss_func = TokenSigmoidFocalLoss(cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_ALPHA,
+                                                         cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_GAMMA)
+
+        self.lang = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
+
+        # self.tokenizer = AutoTokenizer.from_pretrained(self.lang)
+        if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
+            from transformers import CLIPTokenizerFast
+            # self.tokenizer = build_tokenizer(self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
+            if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
+                print("Reuse token 'ðŁĴij</w>' (token_id = 49404) for mask token!")
+                self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                                            from_slow=True, mask_token='ðŁĴij</w>')
+            else:
+                self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
+                                                                            from_slow=True)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.lang)
+
+        # if use shallow contrastive loss
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS \
+                or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
+                assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS == False
+                channels = cfg.MODEL.DYHEAD.CHANNELS
+                num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
+                shallow_input_dim = channels * num_anchors
+            elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
+                assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS == False
+                shallow_input_dim = cfg.MODEL.SWINT.OUT_CHANNELS[-2]
+
+            shallow_log_scale = self.cfg.MODEL.DYHEAD.SHALLOW_LOG_SCALE
+            shallow_contrastive_hdim = cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_HIDDEN_DIM
+            # self.shallow_contrastive_projection_image = nn.Conv2d(channels, num_anchors * shallow_contrastive_hdim,
+            #                                                       kernel_size=1)
+            self.shallow_contrastive_projection_image = nn.Linear(shallow_input_dim, shallow_contrastive_hdim,
+                                                                  bias=True)
+            self.shallow_contrastive_projection_text = nn.Linear(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM,
+                                                                 shallow_contrastive_hdim, bias=True)
+            self.shallow_log_scale = nn.Parameter(torch.Tensor([shallow_log_scale]), requires_grad=True)
+
+        # (initialization) if use shallow contrastive loss
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
+            for modules in [self.shallow_contrastive_projection_image, self.shallow_contrastive_projection_text]:
+                for l in modules.modules():
+                    if isinstance(l, nn.Conv2d):
+                        torch.nn.init.normal_(l.weight, std=0.01)
+                        torch.nn.init.constant_(l.bias, 0)
+                    if isinstance(l, nn.Linear):
+                        torch.nn.init.xavier_uniform_(l.weight)
+                        l.bias.data.fill_(0)
+
+    def NllSoftMaxLoss(self, logits, target):
+        loss_ce = -target * logits.log_softmax(
+            -1)  # basically, only the those positives with positive target_sim will have losses
+        return loss_ce
+
+    def ContrastiveAlignLoss(self, logits, positive_map):
+        positive_logits = -logits.masked_fill(~positive_map, 0)
+        negative_logits = logits  # .masked_fill(positive_map, -1000000)
+
+        boxes_with_pos = positive_map.any(2)
+        pos_term = positive_logits.sum(2)
+        neg_term = negative_logits.logsumexp(2)
+
+        nb_pos = positive_map.sum(2) + 1e-6
+
+        box_to_token_loss = ((pos_term / nb_pos + neg_term)).masked_fill(~boxes_with_pos, 0).sum()
+
+        tokens_with_pos = positive_map.any(1)
+        pos_term = positive_logits.sum(1)
+        neg_term = negative_logits.logsumexp(1)
+
+        nb_pos = positive_map.sum(1) + 1e-6
+
+        tokens_to_boxes_loss = ((pos_term / nb_pos + neg_term)).masked_fill(~tokens_with_pos, 0).sum()
+        tot_loss = (box_to_token_loss + tokens_to_boxes_loss) / 2
+
+        return tot_loss
+
+    def GIoULoss(self, pred, target, anchor, weight=None):
+        pred_boxes = self.box_coder.decode(pred.view(-1, 4), anchor.view(-1, 4))
+        pred_x1 = pred_boxes[:, 0]
+        pred_y1 = pred_boxes[:, 1]
+        pred_x2 = pred_boxes[:, 2]
+        pred_y2 = pred_boxes[:, 3]
+        pred_x2 = torch.max(pred_x1, pred_x2)
+        pred_y2 = torch.max(pred_y1, pred_y2)
+        pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
+
+        gt_boxes = self.box_coder.decode(target.view(-1, 4), anchor.view(-1, 4))
+        target_x1 = gt_boxes[:, 0]
+        target_y1 = gt_boxes[:, 1]
+        target_x2 = gt_boxes[:, 2]
+        target_y2 = gt_boxes[:, 3]
+        target_area = (target_x2 - target_x1) * (target_y2 - target_y1)
+
+        x1_intersect = torch.max(pred_x1, target_x1)
+        y1_intersect = torch.max(pred_y1, target_y1)
+        x2_intersect = torch.min(pred_x2, target_x2)
+        y2_intersect = torch.min(pred_y2, target_y2)
+        area_intersect = torch.zeros(pred_x1.size()).to(pred)
+        mask = (y2_intersect > y1_intersect) * (x2_intersect > x1_intersect)
+        area_intersect[mask] = (x2_intersect[mask] - x1_intersect[mask]) * (y2_intersect[mask] - y1_intersect[mask])
+
+        x1_enclosing = torch.min(pred_x1, target_x1)
+        y1_enclosing = torch.min(pred_y1, target_y1)
+        x2_enclosing = torch.max(pred_x2, target_x2)
+        y2_enclosing = torch.max(pred_y2, target_y2)
+        area_enclosing = (x2_enclosing - x1_enclosing) * (y2_enclosing - y1_enclosing) + 1e-7
+
+        area_union = pred_area + target_area - area_intersect + 1e-7
+        ious = area_intersect / area_union
+        gious = ious - (area_enclosing - area_union) / area_enclosing
+
+        losses = 1 - gious
+
+        if weight is not None and weight.sum() > 0:
+            return (losses * weight).sum()
+        else:
+            assert losses.numel() != 0
+            return losses.sum()
+
+    def prepare_targets(self, targets, anchors, tokenized=None, positive_map=None, proj_tokens=None):
+        cls_labels = []
+        reg_targets = []
+        token_labels = []
+        map_labels = []
+
+        gold_box_od_labels = []
+        od_label_of_tokens_labels = []
+        positive_indices = []
+
+        offset = 0
+
+        for im_i in range(len(targets)):
+            targets_per_im = targets[im_i]
+            assert targets_per_im.mode == "xyxy"
+            # bboxes_per_im = targets_per_im.get_field("boxes")
+            bboxes_per_im = targets_per_im.bbox
+            labels_per_im = targets_per_im.get_field("labels")
+            num_gt = len(bboxes_per_im)
+
+            if positive_map is not None:
+                token_per_im = positive_map[offset:offset + num_gt, :]
+                offset += num_gt
+
+            # Recheck if the label matches with the positive map
+            # print(labels_per_im)
+            # print(token_per_im.nonzero())
+
+            # shallow contrastive
+            if "original_od_label" in targets_per_im.fields():
+                gold_box_od_label = targets_per_im.get_field("original_od_label")
+            if "positive_map_for_od_labels" in targets_per_im.fields():
+                od_label_of_token_per_im = targets_per_im.get_field("positive_map_for_od_labels")
+
+            # print(gold_box_od_label)
+            # print(od_label_of_token_per_im)
+
+            if positive_map is not None and proj_tokens is not None:
+                if "tokens_positive" in targets_per_im.fields():
+                    cur_tokens = targets_per_im.get_field("tokens_positive")
+                else:
+                    cur_tokens = targets_per_im.get_field("tokens")
+                map = torch.zeros((len(cur_tokens), proj_tokens.shape[1]), dtype=torch.bool)
+                for j, tok_list in enumerate(cur_tokens):
+                    for (beg, end) in tok_list:
+                        beg_pos = tokenized.char_to_token(im_i, beg)
+                        end_pos = tokenized.char_to_token(im_i, end - 1)
+                        if beg_pos is None:
+                            try:
+                                beg_pos = tokenized.char_to_token(im_i, beg + 1)
+                                if beg_pos is None:
+                                    beg_pos = tokenized.char_to_token(im_i, beg + 2)
+                            except:
+                                beg_pos = None
+                        if end_pos is None:
+                            try:
+                                end_pos = tokenized.char_to_token(im_i, end - 2)
+                                if end_pos is None:
+                                    end_pos = tokenized.char_to_token(im_i, end - 3)
+                            except:
+                                end_pos = None
+                        if beg_pos is None or end_pos is None:
+                            continue
+
+                        assert beg_pos is not None and end_pos is not None
+                        map[j, beg_pos: end_pos + 1].fill_(True)
+
+            anchors_per_im = cat_boxlist(anchors[im_i])
+
+            num_anchors_per_loc = len(self.cfg.MODEL.RPN.ASPECT_RATIOS) * self.cfg.MODEL.RPN.SCALES_PER_OCTAVE
+            num_anchors_per_level = [len(anchors_per_level.bbox) for anchors_per_level in anchors[im_i]]
+            ious = boxlist_iou(anchors_per_im, targets_per_im)
+
+            gt_cx = (bboxes_per_im[:, 2] + bboxes_per_im[:, 0]) / 2.0
+            gt_cy = (bboxes_per_im[:, 3] + bboxes_per_im[:, 1]) / 2.0
+            gt_points = torch.stack((gt_cx, gt_cy), dim=1)
+
+            anchors_cx_per_im = (anchors_per_im.bbox[:, 2] + anchors_per_im.bbox[:, 0]) / 2.0
+            anchors_cy_per_im = (anchors_per_im.bbox[:, 3] + anchors_per_im.bbox[:, 1]) / 2.0
+            anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1)
+
+            distances = (anchor_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+            # Selecting candidates based on the center distance between anchor box and object
+            candidate_idxs = []
+            star_idx = 0
+            for level, anchors_per_level in enumerate(anchors[im_i]):
+                end_idx = star_idx + num_anchors_per_level[level]
+                distances_per_level = distances[star_idx:end_idx, :]
+                topk = min(self.cfg.MODEL.ATSS.TOPK * num_anchors_per_loc, num_anchors_per_level[level])
+                _, topk_idxs_per_level = distances_per_level.topk(topk, dim=0, largest=False)
+                candidate_idxs.append(topk_idxs_per_level + star_idx)
+                star_idx = end_idx
+            candidate_idxs = torch.cat(candidate_idxs, dim=0)
+
+            # Using the sum of mean and standard deviation as the IoU threshold to select final positive samples
+            candidate_ious = ious[candidate_idxs, torch.arange(num_gt)]
+            iou_mean_per_gt = candidate_ious.mean(0)
+            iou_std_per_gt = candidate_ious.std(0)
+            iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt
+            is_pos = candidate_ious >= iou_thresh_per_gt[None, :]
+
+            # Limiting the final positive samples’ center to object
+            anchor_num = anchors_cx_per_im.shape[0]
+            for ng in range(num_gt):
+                candidate_idxs[:, ng] += ng * anchor_num
+            e_anchors_cx = anchors_cx_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
+            e_anchors_cy = anchors_cy_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
+            candidate_idxs = candidate_idxs.view(-1)
+            l = e_anchors_cx[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 0]
+            t = e_anchors_cy[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 1]
+            r = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs].view(-1, num_gt)
+            b = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs].view(-1, num_gt)
+            is_in_gts = torch.stack([l, t, r, b], dim=1).min(dim=1)[0] > 0.01
+            is_pos = is_pos & is_in_gts
+
+            # if an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
+            ious_inf = torch.full_like(ious, -INF).t().contiguous().view(-1)
+            index = candidate_idxs.view(-1)[is_pos.view(-1)]
+            ious_inf[index] = ious.t().contiguous().view(-1)[index]
+            ious_inf = ious_inf.view(num_gt, -1).t()
+
+            anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)
+            # get positive anchors index from ATSS
+            positive_index = [i[0].item() for i in torch.nonzero(anchors_to_gt_indexs)]
+            cls_labels_per_im = labels_per_im[anchors_to_gt_indexs]
+            cls_labels_per_im[anchors_to_gt_values == -INF] = 0
+
+            if positive_map is not None:
+                token_labels_per_im = token_per_im[anchors_to_gt_indexs]
+                unmatched_labels = torch.zeros(token_labels_per_im.shape[1], device=token_labels_per_im.device)
+                # TODO: temporarially disable the [NoObj] token logic, and only restrict to binary loss
+                unmatched_labels[-1] = 1  # token: none object - > 256
+                token_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels
+                # move from cpu to gpu
+                token_labels_per_im = token_labels_per_im.to(cls_labels_per_im.device)
+
+                # print(token_labels_per_im[anchors_to_gt_values == -INF].shape)
+                # print(cls_labels_per_im[anchors_to_gt_values != -INF][0])
+                # print(token_labels_per_im[anchors_to_gt_values != -INF][0].nonzero())
+
+            if positive_map is not None and proj_tokens is not None:
+                map_labels_per_im = map[anchors_to_gt_indexs]
+                unmatched_labels = torch.zeros(map_labels_per_im.shape[1], dtype=torch.bool,
+                                               device=map_labels_per_im.device)  # map: none False
+                map_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels
+                # move from cpu to gpu
+                map_labels_per_im = map_labels_per_im.to(cls_labels_per_im.device)
+
+                # print(map_labels_per_im[anchors_to_gt_values == -INF].shape)
+                # print(map_labels_per_im[anchors_to_gt_values != -INF][0])
+
+            if positive_map is not None and proj_tokens is not None:
+                gold_box_od_label_per_im = gold_box_od_label[anchors_to_gt_indexs]
+                gold_box_od_label_per_im[anchors_to_gt_values == -INF] = -100
+                # move from cpu to gpu
+                gold_box_od_label_per_im = gold_box_od_label_per_im.to(cls_labels_per_im.device)
+
+                # print(gold_box_od_label_per_im[anchors_to_gt_values != -INF])
+
+            matched_gts = bboxes_per_im[anchors_to_gt_indexs]
+
+            reg_targets_per_im = self.box_coder.encode(matched_gts, anchors_per_im.bbox)
+            cls_labels.append(cls_labels_per_im)
+            reg_targets.append(reg_targets_per_im)
+
+            if positive_map is not None:
+                token_labels.append(token_labels_per_im)
+
+            if positive_map is not None and proj_tokens is not None:
+                map_labels.append(map_labels_per_im)
+                gold_box_od_labels.append(gold_box_od_label_per_im)
+                od_label_of_tokens_labels.append(od_label_of_token_per_im)
+                positive_indices.append(positive_index)
+
+        # print([len(x) for x in positive_indices])
+
+        return cls_labels, reg_targets, token_labels, map_labels, gold_box_od_labels, od_label_of_tokens_labels, positive_indices
+
+    def compute_centerness_targets(self, reg_targets, anchors):
+        gts = self.box_coder.decode(reg_targets, anchors)
+        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+        l = anchors_cx - gts[:, 0]
+        t = anchors_cy - gts[:, 1]
+        r = gts[:, 2] - anchors_cx
+        b = gts[:, 3] - anchors_cy
+        left_right = torch.stack([l, r], dim=1)
+        top_bottom = torch.stack([t, b], dim=1)
+        centerness = torch.sqrt((left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
+                                (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
+        assert not torch.isnan(centerness).any()
+        return centerness
+
+    @custom_fwd(cast_inputs=torch.float32)
+    def __call__(self, box_cls, box_regression, centerness, targets, anchors,
+                 captions=None,
+                 positive_map=None,
+                 token_logits=None,
+                 proj_tokens=None,
+                 contrastive_logits=None,
+                 dot_product_logits=None,
+                 text_masks=None,
+                 shallow_img_emb_feats=None
+                 ):
+
+        tokenized = None
+        if captions is not None:
+            # tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")
+            if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
+                tokenized = self.tokenizer.batch_encode_plus(captions,
+                                                             max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
+                                                             padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest",
+                                                             return_tensors='pt',
+                                                             truncation=True)
+            else:
+                tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")
+
+        labels, reg_targets, token_labels, map_labels, gold_box_od_labels, od_label_of_tokens_labels, positive_indices = self.prepare_targets(targets, anchors,
+                                                                             tokenized,
+                                                                             positive_map,
+                                                                             proj_tokens
+                                                                             )
+
+        N = len(labels)
+
+        box_regression_flatten, box_cls_flatten, token_logits_stacked = concat_box_prediction_layers(
+            box_regression,
+            box_cls,
+            token_logits,
+        )
+
+        # contrastive logits
+        if positive_map is not None and contrastive_logits is not None:
+            contrastive_logits = torch.cat(contrastive_logits, dim=1)
+
+        # dot product soft token logits
+        if dot_product_logits is not None:
+            dot_product_logits = torch.cat(dot_product_logits, dim=1)
+
+        centerness_flatten = [ct.permute(0, 2, 3, 1).reshape(N, -1, 1) for ct in centerness]
+        centerness_flatten = torch.cat(centerness_flatten, dim=1).reshape(-1)
+
+        labels_flatten = torch.cat(labels, dim=0)
+        reg_targets_flatten = torch.cat(reg_targets, dim=0)
+        anchors_flatten = torch.cat([cat_boxlist(anchors_per_image).bbox for anchors_per_image in anchors], dim=0)
+
+        if positive_map is not None:
+            token_labels_stacked = torch.stack(token_labels, dim=0)
+
+        if positive_map is not None and proj_tokens is not None:
+            positive_map_box_to_self_text = None
+            shallow_positive_map = None
+            bs = proj_tokens.shape[0]
+            device = proj_tokens.device
+
+            # NOTE: 0. setup env
+            if dist.is_dist_avail_and_initialized():
+                world_size = dist.get_world_size()
+                rank = torch.distributed.get_rank()
+            else:
+                world_size = 1
+                rank = 0
+
+            if contrastive_logits is not None:
+                positive_map_box_to_self_text = torch.stack(map_labels, dim=0)
+
+            if shallow_img_emb_feats is not None:
+                '''
+                Ultimate:
+                    N*B*(max_anchor_num) x N*B*T
+                Final Goal:
+                    F = B x (max_anchor_num) x N*B*T
+                        X: B x (max_anchor_num) od_labels : [0, 20, 30, ..]
+                        Y: N*B*T: which denotes the od_label of every token
+                    F[i,j] = A[i] == B[j]
+                '''
+                with torch.no_grad():
+                    # NOTE: 1. get X (predicted_box_od_label), which the detection label of every predicted boxes
+                    # predicted_box_od_label: B x A
+
+                    # check memory limitation: prevent # of positive >= # of max_positive
+                    new_positive_indices = []
+                    # print([len(positive_index) for positive_index in positive_indices])
+                    for positive_index in positive_indices:
+                        if len(positive_index) >= self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS:
+                            import random
+                            positive_index = sorted(random.sample(positive_index,
+                                                           self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS))
+                        new_positive_indices.append(positive_index)
+                    # print([len(positive_index) for positive_index in positive_indices])
+
+                    max_len = max([len(positive_index) for positive_index in new_positive_indices])
+                    max_anchor_num = max_len
+
+                    if world_size > 1:
+                        num_anchors = torch.tensor(max_len, device=positive_map.device)
+                        num_anchors_full = [torch.zeros_like(num_anchors) for _ in range(world_size)]
+                        torch.distributed.all_gather(num_anchors_full, num_anchors)
+                        max_anchor_num = max([anchor.item() for anchor in num_anchors_full])
+
+                    new_negative_pad_indices = []
+                    # if not PAD_ZEROS, select random negative paddings
+                    if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS:
+                        for (positive_index, old_positive_index) in zip(new_positive_indices, positive_indices):
+                            negative_index = [i for i in range(len(cat_boxlist(anchors[0]))) if i not in old_positive_index]
+                            import random
+                            negative_pad_index = sorted(random.sample(negative_index,
+                                                               max_anchor_num - len(positive_index)))
+                            new_negative_pad_indices.append(negative_pad_index)
+
+                    predicted_box_od_label = []
+                    for i in range(bs):
+                        predicted_box_od_label.append(
+                            pad_tensor_given_dim_length(gold_box_od_labels[i][new_positive_indices[i]],
+                                                        dim=0,
+                                                        length=max_anchor_num,
+                                                        padding_value=-100,
+                                                        batch_first=False
+                                                        ))
+                    predicted_box_od_label = torch.stack(predicted_box_od_label, dim=0)
+
+                    # if padding, need to create image masks to filter out the paddings
+                    image_masks = None
+                    if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS:
+                        image_masks = torch.zeros((bs, max_anchor_num), dtype=torch.long).to(text_masks.device)
+                        for i in range(bs):
+                            image_masks[i, :len(new_positive_indices[i])] = 1
+
+                    # NOTE: 2. Get Y (od_label_of_tokens)
+                    # od_label_of_tokens: N x B x T
+                    od_label_of_tokens = torch.stack(od_label_of_tokens_labels, dim=0).long()
+                    od_label_of_tokens = gather_tensors(od_label_of_tokens)
+
+                    # NOTE: 3. get F
+                    # F: B*A x N*B*T
+                    mapping_predicted_box_to_all_text = predicted_box_od_label.view(-1).unsqueeze(
+                        1) == od_label_of_tokens.view(-1).unsqueeze(0)
+
+                    # NOTE: 4. we still need to calculate the mapping between predicted box to its corresponding text's mapping
+                    # positive_map_box_to_self_text: B x A x T, leave this for vanilla contrastive alignment loss
+                    positive_map_box_to_self_text = []
+                    for i in range(bs):
+                        positive_map_box_to_self_text.append(
+                            pad_tensor_given_dim_length(map_labels[i][new_positive_indices[i]],
+                                                        dim=0,
+                                                        length=max_anchor_num,
+                                                        padding_value=False,
+                                                        batch_first=False
+                                                        ))
+                    positive_map_box_to_self_text = torch.stack(positive_map_box_to_self_text, dim=0)
+
+                    # change the corresponding place in our batch
+                    for i in range(bs):
+                        mapping_predicted_box_to_all_text[i * max_anchor_num: (i + 1) * max_anchor_num,
+                        (rank * bs + i) * 256: (rank * bs + i + 1) * 256] = positive_map_box_to_self_text[i]
+
+                    # NOTE: 5. communicate and get positive map
+                    # mapping_predicted_box_to_all_text: N*B*A x N*B*T
+                    mapping_predicted_box_to_all_text = gather_tensors(mapping_predicted_box_to_all_text).view(-1,
+                                                                                                               mapping_predicted_box_to_all_text.size(
+                                                                                                                   -1))
+                    shallow_positive_map = mapping_predicted_box_to_all_text  # This is the true positive map
+                    shallow_positive_map = shallow_positive_map.unsqueeze(0)
+
+                    # Get text attention masks
+                    text_attention_mask = torch.zeros((bs, 256), dtype=torch.long)  # B x 256
+                    for i in range(bs):
+                        text_attention_mask[i, :len(text_masks[i])] = text_masks[i]
+                    text_attention_mask = gather_tensors(
+                        text_attention_mask.bool().to(device))  # N x B x 256
+
+                    # if PAD_ZEROS, get image masks
+                    if image_masks is not None:
+                        image_attention_mask = torch.zeros((bs, max_anchor_num), dtype=torch.long)  # B x max_anchor
+                        for i in range(bs):
+                            image_attention_mask[i, :len(image_masks[i])] = image_masks[i]
+                        image_attention_mask = gather_tensors(
+                            image_attention_mask.bool().to(device))  # N x B x max_anchor
+
+                # NOTE: 6. calculate shallow contrastive logits
+                shallow_proj_tokens = F.normalize(self.shallow_contrastive_projection_text(proj_tokens), p=2, dim=-1)
+
+                shallow_normalized_img_embs = []
+                if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
+                    # choice 1:use features from SWINT backbone layer (c4) before vl fusion
+                    from maskrcnn_benchmark.layers.roi_align import ROIAlignV2
+                    pooler = ROIAlignV2((1, 1), 1./16, 0)
+                    # get positive features
+                    for i in range(bs):
+                        rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_positive_indices[i]])
+                        roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), rois)
+                        roi_feature = roi_feature.squeeze(-1).squeeze(-1)
+                        shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(roi_feature)
+                        shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1)
+                        if image_masks is not None:
+                            # pad zeros
+                            shallow_normalized_img_embs.append(
+                                pad_tensor_given_dim_length(shallow_normalized_img_emb,
+                                                            dim=0,
+                                                            length=max_anchor_num,
+                                                            padding_value=0.0,
+                                                            batch_first=False
+                                                            ))
+                        else:
+                            # pad negatives
+                            negative_rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_negative_pad_indices[i]])
+                            negative_roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), negative_rois)
+                            negative_roi_feature = negative_roi_feature.squeeze(-1).squeeze(-1)
+                            negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(negative_roi_feature)
+                            negative_shallow_normalized_img_emb = F.normalize(negative_shallow_contrastive_proj_queries,
+                                                                              p=2, dim=-1)
+                            shallow_normalized_img_embs.append(
+                                pad_random_negative_tensor_given_length(shallow_normalized_img_emb,
+                                                                        negative_shallow_normalized_img_emb,
+                                                                        length=max_anchor_num
+                                                                        )
+                            )
+                elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
+                    # choice 2:use features after FPN
+                    shallow_img_embs = torch.cat(shallow_img_emb_feats, dim=1)
+                    # get positive features
+                    for i in range(bs):
+                        shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(shallow_img_embs[i, new_positive_indices[i], :])
+                        shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1)
+                        if image_masks is not None:
+                            # pad zeros
+                            shallow_normalized_img_embs.append(
+                                pad_tensor_given_dim_length(shallow_normalized_img_emb,
+                                                            dim=0,
+                                                            length=max_anchor_num,
+                                                            padding_value=0.0,
+                                                            batch_first=False
+                                                            ))
+                        else:
+                            # pad negatives
+                            negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(shallow_img_embs[i, new_negative_pad_indices[i], :])
+                            negative_shallow_normalized_img_emb = F.normalize(negative_shallow_contrastive_proj_queries,
+                                                                              p=2, dim=-1)
+                            shallow_normalized_img_embs.append(
+                                pad_random_negative_tensor_given_length(shallow_normalized_img_emb,
+                                                                        negative_shallow_normalized_img_emb,
+                                                                        length=max_anchor_num
+                                                                        )
+                            )
+
+                shallow_normalized_img_embs = torch.stack(shallow_normalized_img_embs, dim=0)
+                shallow_normalized_text_emb = shallow_proj_tokens
+                shallow_normalized_text_emb = pad_tensor_given_dim_length(shallow_normalized_text_emb,
+                                                                          dim=1,
+                                                                          length=256,
+                                                                          padding_value=0.0)
+
+                gathered_shallow_normalized_img_emb = gather_tensors(shallow_normalized_img_embs)
+                gathered_shallow_normalized_text_emb = gather_tensors(shallow_normalized_text_emb)
+                gathered_shallow_normalized_img_emb = gathered_shallow_normalized_img_emb.view(-1,
+                                                                                               gathered_shallow_normalized_img_emb.size(
+                                                                                                   -1))
+                gathered_shallow_normalized_text_emb = gathered_shallow_normalized_text_emb.view(-1,
+                                                                                                 gathered_shallow_normalized_text_emb.size(
+                                                                                                     -1))
+                shallow_contrastive_logits = (
+                        torch.matmul(gathered_shallow_normalized_img_emb,
+                                     gathered_shallow_normalized_text_emb.transpose(-1,
+                                                                                    -2)) / self.shallow_log_scale.exp())
+                shallow_contrastive_logits = shallow_contrastive_logits.unsqueeze(0)
+
+                # apply text mask
+                text_attention_mask = text_attention_mask.view(-1).unsqueeze(0).unsqueeze(0)
+                text_attention_mask = text_attention_mask.repeat(1, shallow_contrastive_logits.size(1),
+                                                                 1)  # copy along the image feature dimension
+                shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~text_attention_mask, -1000000)
+
+                # if PAD ZEROS, apply image mask
+                if image_masks is not None:
+                    image_attention_mask = image_attention_mask.view(-1).unsqueeze(0).unsqueeze(-1)
+                    image_attention_mask = image_attention_mask.repeat(1, 1, shallow_contrastive_logits.size(
+                        2))  # copy along the text feature dimension
+                    shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~image_attention_mask, -1000000)
+
+                # Note: 7. calculate image and text logits and maps
+                shallow_image_logits = shallow_contrastive_logits[:,
+                                       (rank * bs) * max_anchor_num: (rank * bs + bs) * max_anchor_num, :]
+                shallow_image_positive_map = normalized_positive_map(
+                    shallow_positive_map[:, (rank * bs) * max_anchor_num: (rank * bs + bs) * max_anchor_num, :])
+
+                shallow_text_logits = shallow_contrastive_logits[:, :,
+                                      (rank * bs) * 256: (rank * bs + bs) * 256].transpose(1,
+                                                                                           2)
+                shallow_text_positive_map = normalized_positive_map(
+                    shallow_positive_map[:, :, (rank * bs) * 256: (rank * bs + bs) * 256].transpose(1, 2))
+
+        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
+
+        num_gpus = get_world_size()
+        total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
+        num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)
+
+        cls_loss = self.cls_loss_func(box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu
+
+        token_logits_loss = None
+        contrastive_align_loss = None
+        dot_product_token_loss = None
+        shallow_contrastive_loss = None
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
+            token_logits_loss = self.token_loss_func(token_logits_stacked,
+                                                     token_labels_stacked, text_masks=text_masks,
+                                                     version="binary") / num_pos_avg_per_gpu
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+            contrastive_align_loss = self.ContrastiveAlignLoss(contrastive_logits, positive_map_box_to_self_text) / num_pos_avg_per_gpu
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            dot_product_token_loss = self.token_loss_func(dot_product_logits,
+                                                          token_labels_stacked, text_masks=text_masks,
+                                                          version="binary") / num_pos_avg_per_gpu
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS or \
+                self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
+            box_to_token_loss = self.NllSoftMaxLoss(shallow_image_logits, shallow_image_positive_map).sum()
+            token_to_box_loss = self.NllSoftMaxLoss(shallow_text_logits, shallow_text_positive_map).sum()
+            tot_loss = (box_to_token_loss + token_to_box_loss) / 2
+            shallow_contrastive_loss = tot_loss / num_pos_avg_per_gpu
+
+        box_regression_flatten = box_regression_flatten[pos_inds]
+        reg_targets_flatten = reg_targets_flatten[pos_inds]
+        anchors_flatten = anchors_flatten[pos_inds]
+        centerness_flatten = centerness_flatten[pos_inds]
+
+        if pos_inds.numel() > 0:
+            centerness_targets = self.compute_centerness_targets(reg_targets_flatten, anchors_flatten)
+
+            sum_centerness_targets_avg_per_gpu = reduce_sum(centerness_targets.sum()).item() / float(num_gpus)
+            reg_loss = self.GIoULoss(box_regression_flatten, reg_targets_flatten, anchors_flatten,
+                                     weight=centerness_targets) / sum_centerness_targets_avg_per_gpu
+            centerness_loss = self.centerness_loss_func(centerness_flatten, centerness_targets) / num_pos_avg_per_gpu
+        else:
+            reg_loss = box_regression_flatten.sum()
+            reduce_sum(centerness_flatten.new_tensor([0.0]))
+            centerness_loss = centerness_flatten.sum()
+
+        return cls_loss, reg_loss * self.cfg.MODEL.ATSS.REG_LOSS_WEIGHT, centerness_loss, \
+               token_logits_loss, \
+               contrastive_align_loss, \
+               dot_product_token_loss, \
+               shallow_contrastive_loss
+
+
+def generate_anchor_labels(matched_targets):
+    labels_per_image = matched_targets.get_field("labels")
+    return labels_per_image
+
+
+def make_focal_loss_evaluator(cfg, box_coder):
+    matcher = Matcher(
+        cfg.MODEL.FOCAL.FG_IOU_THRESHOLD,
+        cfg.MODEL.FOCAL.BG_IOU_THRESHOLD,
+        allow_low_quality_matches=True,
+    )
+    sigmoid_focal_loss = SigmoidFocalLoss(
+        cfg.MODEL.FOCAL.LOSS_GAMMA,
+        cfg.MODEL.FOCAL.LOSS_ALPHA
+    )
+
+    loss_evaluator = FocalLossComputation(
+        matcher,
+        box_coder,
+        generate_anchor_labels,
+        sigmoid_focal_loss,
+        bbox_reg_beta=cfg.MODEL.FOCAL.BBOX_REG_BETA,
+        regress_norm=cfg.MODEL.FOCAL.BBOX_REG_WEIGHT,
+    )
+    return loss_evaluator
+
+
+def make_rpn_loss_evaluator(cfg, box_coder):
+    matcher = Matcher(
+        cfg.MODEL.RPN.FG_IOU_THRESHOLD,
+        cfg.MODEL.RPN.BG_IOU_THRESHOLD,
+        allow_low_quality_matches=True,
+    )
+
+    fg_bg_sampler = BalancedPositiveNegativeSampler(
+        cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION
+    )
+
+    loss_evaluator = RPNLossComputation(matcher, fg_bg_sampler, box_coder)
+    return loss_evaluator
+
+
+def make_fcos_loss_evaluator(cfg):
+    loss_evaluator = FCOSLossComputation(cfg)
+    return loss_evaluator
+
+
+def make_atss_loss_evaluator(cfg, box_coder):
+    loss_evaluator = ATSSLossComputation(cfg, box_coder)
+    return loss_evaluator
diff --git a/maskrcnn_benchmark/modeling/rpn/modeling_bert.py b/maskrcnn_benchmark/modeling/rpn/modeling_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7eda26e6e13262cb281d7a53acd2f5e515fd391
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/modeling_bert.py
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model. """
+
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+import pdb
+from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer
+
+
+def clamp_values(vector, min_val = -50000, max_val = 50000):
+    vector = torch.clamp(vector, min = min_val, max = max_val)
+    return vector
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config, clamp_min_for_underflow=False, clamp_max_for_overflow=False):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+        self.clamp_min_for_underflow = clamp_min_for_underflow
+        self.clamp_max_for_overflow = clamp_max_for_overflow
+
+        self.is_decoder = config.is_decoder
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        if self.clamp_min_for_underflow:
+            attention_scores = torch.clamp(attention_scores, min=-50000) # Do not increase -50000, data type half has quite limited range
+        if self.clamp_max_for_overflow:
+            attention_scores = torch.clamp(attention_scores, max=50000) # Do not increase 50000, data type half has quite limited range
+
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        # if math.isnan(attention_probs.sum().item()):
+        #     for i in range(attention_probs.size(1)):
+        #         for j in range(attention_probs.size(2)):
+        #             if math.isnan(attention_probs[0, i, j].sum().item()):
+        #                 print(i, j)
+        #                 pdb.set_trace()
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config, clamp_min_for_underflow=False, clamp_max_for_overflow=False):
+        super().__init__()
+        self.self = BertSelfAttention(config, clamp_min_for_underflow, clamp_max_for_overflow)
+        self.output = BertSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = clamp_values(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = clamp_values(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = clamp_values(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        hidden_states = clamp_values(hidden_states)
+        return hidden_states
+
diff --git a/maskrcnn_benchmark/modeling/rpn/retina.py b/maskrcnn_benchmark/modeling/rpn/retina.py
new file mode 100644
index 0000000000000000000000000000000000000000..146449c7cc930bef93d89471d021979bdea7546e
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/retina.py
@@ -0,0 +1,156 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from maskrcnn_benchmark.modeling import registry
+from maskrcnn_benchmark.modeling.box_coder import BoxCoder
+from .loss import make_focal_loss_evaluator
+from .anchor_generator import make_anchor_generator_complex
+from .inference import make_retina_postprocessor
+
+
+@registry.RPN_HEADS.register("RetinaNetHead")
+class RetinaNetHead(torch.nn.Module):
+    """
+    Adds a RetinNet head with classification and regression heads
+    """
+
+    def __init__(self, cfg):
+        """
+        Arguments:
+            in_channels (int): number of channels of the input feature
+            num_anchors (int): number of anchors to be predicted
+        """
+        super(RetinaNetHead, self).__init__()
+        # TODO: Implement the sigmoid version first.
+        num_classes = cfg.MODEL.RETINANET.NUM_CLASSES - 1
+        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        if cfg.MODEL.RPN.USE_FPN:
+            num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
+        else:
+            num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * len(cfg.MODEL.RPN.ANCHOR_SIZES)
+
+        cls_tower = []
+        bbox_tower = []
+        for i in range(cfg.MODEL.RETINANET.NUM_CONVS):
+            cls_tower.append(
+                nn.Conv2d(
+                    in_channels,
+                    in_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1
+                )
+            )
+            cls_tower.append(nn.ReLU())
+            bbox_tower.append(
+                nn.Conv2d(
+                    in_channels,
+                    in_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1
+                )
+            )
+            bbox_tower.append(nn.ReLU())
+
+        self.add_module('cls_tower', nn.Sequential(*cls_tower))
+        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
+        self.cls_logits = nn.Conv2d(
+            in_channels, num_anchors * num_classes, kernel_size=3, stride=1,
+            padding=1
+        )
+        self.bbox_pred = nn.Conv2d(
+            in_channels,  num_anchors * 4, kernel_size=3, stride=1,
+            padding=1
+        )
+
+        # Initialization
+        for modules in [self.cls_tower, self.bbox_tower, self.cls_logits,
+                  self.bbox_pred]:
+            for l in modules.modules():
+                if isinstance(l, nn.Conv2d):
+                    torch.nn.init.normal_(l.weight, std=0.01)
+                    torch.nn.init.constant_(l.bias, 0)
+
+
+        # retinanet_bias_init
+        prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        torch.nn.init.constant_(self.cls_logits.bias, bias_value)
+
+    def forward(self, x):
+        logits = []
+        bbox_reg = []
+        for feature in x:
+            logits.append(self.cls_logits(self.cls_tower(feature)))
+            bbox_reg.append(self.bbox_pred(self.bbox_tower(feature)))
+        return logits, bbox_reg
+
+
+class RetinaNetModule(torch.nn.Module):
+    """
+    Module for RetinaNet computation. Takes feature maps from the backbone and
+    RetinaNet outputs and losses. Only Test on FPN now.
+    """
+
+    def __init__(self, cfg):
+        super(RetinaNetModule, self).__init__()
+
+        self.cfg = cfg.clone()
+
+        anchor_generator = make_anchor_generator_complex(cfg)
+        head = RetinaNetHead(cfg)
+
+        box_coder = BoxCoder(weights=(10., 10., 5., 5.))
+
+        box_selector_test = make_retina_postprocessor(cfg, box_coder, is_train=False)
+
+        loss_evaluator = make_focal_loss_evaluator(cfg, box_coder)
+
+        self.anchor_generator = anchor_generator
+        self.head = head
+        self.box_selector_test = box_selector_test
+        self.loss_evaluator = loss_evaluator
+
+    def forward(self, images, features, targets=None):
+        """
+        Arguments:
+            images (ImageList): images for which we want to compute the predictions
+            features (list[Tensor]): features computed from the images that are
+                used for computing the predictions. Each tensor in the list
+                correspond to different feature levels
+            targets (list[BoxList): ground-truth boxes present in the image (optional)
+
+        Returns:
+            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
+                image.
+            losses (dict[Tensor]): the losses for the model during training. During
+                testing, it is an empty dict.
+        """
+        box_cls, box_regression = self.head(features)
+        anchors = self.anchor_generator(images, features)
+
+        if self.training:
+            return self._forward_train(anchors, box_cls, box_regression, targets)
+        else:
+            return self._forward_test(anchors, box_cls, box_regression)
+
+    def _forward_train(self, anchors, box_cls, box_regression, targets):
+
+        loss_box_cls, loss_box_reg = self.loss_evaluator(
+            anchors, box_cls, box_regression, targets
+        )
+        losses = {
+            "loss_retina_cls": loss_box_cls,
+            "loss_retina_reg": loss_box_reg,
+        }
+        return anchors, losses
+
+    def _forward_test(self, anchors, box_cls, box_regression):
+        boxes = self.box_selector_test(anchors, box_cls, box_regression)
+        return boxes, {}
+
+
diff --git a/maskrcnn_benchmark/modeling/rpn/rpn.py b/maskrcnn_benchmark/modeling/rpn/rpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f300f67773358a11d890999a556a0dbea3bfdeb
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/rpn.py
@@ -0,0 +1,171 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from maskrcnn_benchmark.modeling import registry
+from maskrcnn_benchmark.modeling.box_coder import BoxCoder
+from .loss import make_rpn_loss_evaluator
+from .anchor_generator import make_anchor_generator
+from .inference import make_rpn_postprocessor
+
+
+@registry.RPN_HEADS.register("SimpleRPNHead")
+class mRPNHead(nn.Module):
+    """
+    Adds a simple RPN Head with classification and regression heads
+    """
+
+    def __init__(self, cfg, in_channels, num_anchors):
+        """
+        Arguments:
+            cfg              : config
+            in_channels (int): number of channels of the input feature
+            num_anchors (int): number of anchors to be predicted
+        """
+        super(mRPNHead, self).__init__()
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
+        self.bbox_pred = nn.Conv2d(
+            in_channels, num_anchors * 4, kernel_size=1, stride=1
+        )
+
+        for l in [self.cls_logits, self.bbox_pred]:
+            torch.nn.init.normal_(l.weight, std=0.01)
+            torch.nn.init.constant_(l.bias, 0)
+
+    def forward(self, x):
+        logits = []
+        bbox_reg = []
+        for feature in x:
+            t = F.relu(feature)
+            logits.append(self.cls_logits(t))
+            bbox_reg.append(self.bbox_pred(t))
+        return logits, bbox_reg
+
+
+@registry.RPN_HEADS.register("SingleConvRPNHead")
+class RPNHead(nn.Module):
+    """
+    Adds a simple RPN Head with classification and regression heads
+    """
+
+    def __init__(self, cfg, in_channels, num_anchors):
+        """
+        Arguments:
+            cfg              : config
+            in_channels (int): number of channels of the input feature
+            num_anchors (int): number of anchors to be predicted
+        """
+        super(RPNHead, self).__init__()
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
+        self.bbox_pred = nn.Conv2d(
+            in_channels, num_anchors * 4, kernel_size=1, stride=1
+        )
+
+        for l in [self.conv, self.cls_logits, self.bbox_pred]:
+            torch.nn.init.normal_(l.weight, std=0.01)
+            torch.nn.init.constant_(l.bias, 0)
+
+    def forward(self, x):
+        logits = []
+        bbox_reg = []
+        for feature in x:
+            t = F.relu(self.conv(feature))
+            logits.append(self.cls_logits(t))
+            bbox_reg.append(self.bbox_pred(t))
+        return logits, bbox_reg
+
+
+class RPNModule(torch.nn.Module):
+    """
+    Module for RPN computation. Takes feature maps from the backbone and RPN
+    proposals and losses. Works for both FPN and non-FPN.
+    """
+
+    def __init__(self, cfg):
+        super(RPNModule, self).__init__()
+
+        self.cfg = cfg.clone()
+
+        anchor_generator = make_anchor_generator(cfg)
+
+        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        rpn_head = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]
+        head = rpn_head(
+            cfg, in_channels, anchor_generator.num_anchors_per_location()[0]
+        )
+
+        rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
+        box_selector_train = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=True)
+        box_selector_test = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=False)
+
+        loss_evaluator = make_rpn_loss_evaluator(cfg, rpn_box_coder)
+
+        self.anchor_generator = anchor_generator
+        self.head = head
+        self.box_selector_train = box_selector_train
+        self.box_selector_test = box_selector_test
+        self.loss_evaluator = loss_evaluator
+
+    def forward(self, images, features, targets=None):
+        """
+        Arguments:
+            images (ImageList): images for which we want to compute the predictions
+            features (list[Tensor]): features computed from the images that are
+                used for computing the predictions. Each tensor in the list
+                correspond to different feature levels
+            targets (list[BoxList): ground-truth boxes present in the image (optional)
+
+        Returns:
+            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
+                image.
+            losses (dict[Tensor]): the losses for the model during training. During
+                testing, it is an empty dict.
+        """
+        objectness, rpn_box_regression = self.head(features)
+        anchors = self.anchor_generator(images, features)
+
+        if self.training:
+            return self._forward_train(anchors, objectness, rpn_box_regression, targets)
+        else:
+            return self._forward_test(anchors, objectness, rpn_box_regression)
+
+    def _forward_train(self, anchors, objectness, rpn_box_regression, targets):
+        if self.cfg.MODEL.RPN_ONLY:
+            # When training an RPN-only model, the loss is determined by the
+            # predicted objectness and rpn_box_regression values and there is
+            # no need to transform the anchors into predicted boxes; this is an
+            # optimization that avoids the unnecessary transformation.
+            boxes = anchors
+        else:
+            # For end-to-end models, anchors must be transformed into boxes and
+            # sampled into a training batch.
+            with torch.no_grad():
+                boxes = self.box_selector_train(
+                    anchors, objectness, rpn_box_regression, targets
+                )
+        loss_objectness, loss_rpn_box_reg = self.loss_evaluator(
+            anchors, objectness, rpn_box_regression, targets
+        )
+        losses = {
+            "loss_objectness": loss_objectness,
+            "loss_rpn_box_reg": loss_rpn_box_reg,
+        }
+        return boxes, losses
+
+    def _forward_test(self, anchors, objectness, rpn_box_regression):
+        boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
+        if self.cfg.MODEL.RPN_ONLY:
+            # For end-to-end models, the RPN proposals are an intermediate state
+            # and don't bother to sort them in decreasing score order. For RPN-only
+            # models, the proposals are the final output and we return them in
+            # high-to-low confidence order.
+            inds = [
+                box.get_field("objectness").sort(descending=True)[1] for box in boxes
+            ]
+            boxes = [box[ind] for box, ind in zip(boxes, inds)]
+        return boxes, {}
\ No newline at end of file
diff --git a/maskrcnn_benchmark/modeling/rpn/transformer.py b/maskrcnn_benchmark/modeling/rpn/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f0cd1efc216113cb3ef78896356cc3c35c6354
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/transformer.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+import copy
+from typing import Optional, List
+
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                 activation="relu", normalize_before=False):
+        super(TransformerEncoderLayer, self).__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def forward(self, src,
+                src_mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None):
+        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
+                              key_padding_mask=src_key_padding_mask)[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
diff --git a/maskrcnn_benchmark/modeling/rpn/vldyhead.py b/maskrcnn_benchmark/modeling/rpn/vldyhead.py
new file mode 100644
index 0000000000000000000000000000000000000000..2edbb5d477c80e9abe760320fb7311fcc3efdcbe
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/rpn/vldyhead.py
@@ -0,0 +1,1036 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from collections import defaultdict
+
+from .inference import make_atss_postprocessor
+from .loss import make_atss_loss_evaluator
+from .anchor_generator import make_anchor_generator_complex
+
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.layers import Scale, DYReLU, SELayer, ModulatedDeformConv
+from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d
+from maskrcnn_benchmark.modeling.backbone.fbnet import *
+from maskrcnn_benchmark.engine.inference import create_positive_map_label_to_token_from_positive_map
+from ..utils import cat, concat_box_prediction_layers, permute_and_flatten
+
+from maskrcnn_benchmark.utils.fuse_helper import FeatureResizer, func_attention, _make_mlp, _make_conv, _make_coord, \
+    BiAttentionBlock, AttentionT2I, BiAttentionBlockForCheckpoint, BertLMPredictionHead
+from transformers.models.bert.modeling_bert import BertConfig, BertAttention, BertIntermediate, BertOutput, \
+    BertPreTrainedModel
+from transformers.modeling_utils import apply_chunking_to_forward
+import torch.utils.checkpoint as checkpoint
+import pdb
+
+from maskrcnn_benchmark.modeling.language_backbone.clip_model import QuickGELU, LayerNorm, DropPath
+from timm.models.layers import DropPath, trunc_normal_
+
+class h_sigmoid(nn.Module):
+    def __init__(self, inplace=True, h_max=1):
+        super(h_sigmoid, self).__init__()
+        self.relu = nn.ReLU6(inplace=inplace)
+        self.h_max = h_max
+
+    def forward(self, x):
+        return self.relu(x + 3) * self.h_max / 6
+
+
+class BoxCoder(object):
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def encode(self, gt_boxes, anchors):
+        TO_REMOVE = 1  # TODO remove
+        ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
+        ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
+        ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
+        ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
+
+        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE
+        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE
+        gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2
+        gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2
+
+        wx, wy, ww, wh = (10., 10., 5., 5.)
+        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+        targets_dw = ww * torch.log(gt_widths / ex_widths)
+        targets_dh = wh * torch.log(gt_heights / ex_heights)
+        targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+
+        return targets
+
+    def decode(self, preds, anchors):
+        anchors = anchors.to(preds.dtype)
+
+        TO_REMOVE = 1  # TODO remove
+        widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
+        heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
+        ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
+        ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
+
+        wx, wy, ww, wh = (10., 10., 5., 5.)
+        dx = preds[:, 0::4] / wx
+        dy = preds[:, 1::4] / wy
+        dw = preds[:, 2::4] / ww
+        dh = preds[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=math.log(1000. / 16))
+        dh = torch.clamp(dh, max=math.log(1000. / 16))
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        pred_boxes = torch.zeros_like(preds)
+        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1)
+        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1)
+        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1)
+        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1)
+
+        return pred_boxes
+
+
+class Conv3x3Norm(torch.nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 groups=1,
+                 deformable=False,
+                 bn_type=None):
+        super(Conv3x3Norm, self).__init__()
+
+        if deformable:
+            self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1,
+                                            groups=groups)
+        else:
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups)
+
+        if isinstance(bn_type, (list, tuple)):
+            assert len(bn_type) == 2
+            assert bn_type[0] == "gn"
+            gn_group = bn_type[1]
+            bn_type = bn_type[0]
+
+        if bn_type == "bn":
+            bn_op = nn.BatchNorm2d(out_channels)
+        elif bn_type == "sbn":
+            bn_op = nn.SyncBatchNorm(out_channels)
+        elif bn_type == "nsbn":
+            bn_op = NaiveSyncBatchNorm2d(out_channels)
+        elif bn_type == "gn":
+            bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=out_channels)
+        elif bn_type == "af":
+            bn_op = FrozenBatchNorm2d(out_channels)
+        if bn_type is not None:
+            self.bn = bn_op
+        else:
+            self.bn = None
+
+    def forward(self, input, **kwargs):
+        x = self.conv(input, **kwargs)
+        if self.bn:
+            x = self.bn(x)
+        return x
+
+
+class DyConv(torch.nn.Module):
+    def __init__(self,
+                 in_channels=256,
+                 out_channels=256,
+                 conv_func=nn.Conv2d,
+                 use_dyfuse=True,
+                 use_dyrelu=False,
+                 use_deform=False
+                 ):
+        super(DyConv, self).__init__()
+
+        self.DyConv = nn.ModuleList()
+        self.DyConv.append(conv_func(in_channels, out_channels, 1))
+        self.DyConv.append(conv_func(in_channels, out_channels, 1))
+        self.DyConv.append(conv_func(in_channels, out_channels, 2))
+
+        if use_dyfuse:
+            self.AttnConv = nn.Sequential(
+                nn.AdaptiveAvgPool2d(1),
+                nn.Conv2d(in_channels, 1, kernel_size=1),
+                nn.ReLU(inplace=True))
+            self.h_sigmoid = h_sigmoid()
+        else:
+            self.AttnConv = None
+
+        if use_dyrelu:
+            self.relu = DYReLU(in_channels, out_channels)
+        else:
+            self.relu = nn.ReLU()
+
+        if use_deform:
+            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
+        else:
+            self.offset = None
+
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.DyConv.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight.data, 0, 0.01)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+        if self.AttnConv is not None:
+            for m in self.AttnConv.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.normal_(m.weight.data, 0, 0.01)
+                    if m.bias is not None:
+                        m.bias.data.zero_()
+
+    def forward(self, inputs):
+        visual_feats = inputs["visual"]
+        language_dict_features = inputs["lang"]
+
+        next_x = []
+        for level, feature in enumerate(visual_feats):
+
+            conv_args = dict()
+            if self.offset is not None:
+                offset_mask = self.offset(feature)
+                offset = offset_mask[:, :18, :, :]
+                mask = offset_mask[:, 18:, :, :].sigmoid()
+                conv_args = dict(offset=offset, mask=mask)
+
+            temp_fea = [self.DyConv[1](feature, **conv_args)]
+
+            if level > 0:
+                temp_fea.append(self.DyConv[2](visual_feats[level - 1], **conv_args))
+            if level < len(visual_feats) - 1:
+                temp_fea.append(F.upsample_bilinear(self.DyConv[0](visual_feats[level + 1], **conv_args),
+                                                    size=[feature.size(2), feature.size(3)]))
+            mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False)
+
+            if self.AttnConv is not None:
+                attn_fea = []
+                res_fea = []
+                for fea in temp_fea:
+                    res_fea.append(fea)
+                    attn_fea.append(self.AttnConv(fea))
+
+                res_fea = torch.stack(res_fea)
+                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea))
+
+                mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False)
+
+            next_x.append(mean_fea)
+
+        next_x = [self.relu(item) for item in next_x]
+
+        features_dict = {"visual": next_x,
+                         "lang": language_dict_features}
+
+        return features_dict
+
+
+class BertEncoderLayer(BertPreTrainedModel):
+    def __init__(self, config,  clamp_min_for_underflow = False, clamp_max_for_overflow = False):
+        super().__init__(config)
+        self.config = config
+
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+
+        from maskrcnn_benchmark.modeling.rpn.modeling_bert import BertAttention, BertIntermediate, BertOutput
+
+        self.attention = BertAttention(config,  clamp_min_for_underflow, clamp_max_for_overflow)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(self, inputs):
+        language_dict_features = inputs["lang"]
+        hidden_states = language_dict_features["hidden"]
+        attention_mask = language_dict_features["masks"]
+
+        device = hidden_states.device
+        input_shape = hidden_states.size()[:-1]
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+        self_attention_outputs = self.attention(
+            hidden_states,
+            extended_attention_mask,
+            None,
+            output_attentions=False,
+            past_key_value=None,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+        hidden_states = outputs[0]
+
+        language_dict_features["hidden"] = hidden_states
+
+        features_dict = {"visual": inputs["visual"],
+                         "lang": language_dict_features
+                         }
+
+        return features_dict
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class CLIPTransformerLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        d_model = self.config.MODEL.CLIP.WIDTH
+        n_head = self.config.MODEL.CLIP.HEADS
+        drop_path = self.config.MODEL.CLIP.DROP_PATH
+        self.context_length = self.config.MODEL.CLIP.CONTEXT_LENGTH
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = None
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Linear, nn.Conv2d)):
+            trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+            nn.init.constant_(m.bias, 0)
+
+    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
+            if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]
+
+    def forward(self, inputs):
+        language_dict_features = inputs["lang"]
+        x = language_dict_features["hidden"]
+        mask = language_dict_features["masks"]
+        # get extended attention mask for nn.MultiHeadAttention
+        key_padding_mask = (1.0 - mask).to(torch.bool)
+
+        x = x.permute(1, 0, 2)
+        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
+        x = x + self.drop_path(self.mlp(self.ln_2(x)))
+        x = x.permute(1, 0, 2)
+
+        language_dict_features["hidden"] = x
+        features_dict = {"visual": inputs["visual"],
+                         "lang": language_dict_features
+                         }
+        return features_dict
+
+
+class DummyLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs):
+        return inputs
+
+
+class VLFuse(torch.nn.Module):
+    """
+    Early Fusion Module
+    """
+
+    def __init__(self, cfg):
+        super(VLFuse, self).__init__()
+        self.init_configs(cfg)
+        self.cfg = cfg
+
+        self.use_checkpoint = False
+        if hasattr(cfg.MODEL.DYHEAD, 'USE_CHECKPOINT'):
+            self.use_checkpoint = cfg.MODEL.DYHEAD.USE_CHECKPOINT
+            self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
+
+        # early fusion module
+        print("EARLY FUSION ON, USING {}".format(cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE))
+        if cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-S":
+            # single-direction (text->image)
+            # text -> image
+            self.t2i_attn = AttentionT2I(q_dim=self.joint_embedding_size,
+                                           k_dim=self.lang_dim,
+                                           embed_dim=self.embed_dim,
+                                           num_heads=self.n_head,
+                                           hidden_dim=self.t2i_hidden_dim,
+                                           dropout=0.1,
+                                           drop_path=.0,
+                                           init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS,
+                                           mode="t2i",
+                                           use_layer_scale=cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_LAYER_SCALE,
+                                           clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW,
+                                           clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW
+                                           )
+
+        elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-B":
+            # bi-direction (text->image, image->text)
+            self.b_attn = BiAttentionBlockForCheckpoint(v_dim=self.joint_embedding_size,
+                        l_dim=self.lang_dim,
+                        embed_dim=self.embed_dim,
+                        num_heads=self.n_head,
+                        hidden_dim=self.i2t_hidden_dim,
+                        dropout=0.1,
+                        drop_path=.0,
+                        init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS,
+                        cfg=cfg
+                        )
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL and self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT:
+                self.shrink_lang = FeatureResizer(self.lang_dim * 5,
+                                self.lang_dim, 0.1)
+
+
+        elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "SCAN":
+            # single-direction (text->image)
+            self.mapping_lang = _make_mlp(self.lang_dim,
+                                          self.joint_embedding_size,
+                                          self.joint_embedding_dropout)
+            self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1) \
+                                               for _ in range(5)])
+
+        elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "FILM":
+            # single-direction (text->image)
+            self.mapping_lang = _make_mlp(self.lang_dim,
+                                          self.joint_embedding_size,
+                                          self.joint_embedding_dropout)
+            self.gamma = nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5))
+            self.beta = nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5))
+
+            self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1) \
+                                               for _ in range(5)])
+
+        else:
+            print("NO FUSION INVOLVED.")
+
+    def init_configs(self, cfg):
+        # common params
+        self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
+        self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE
+        self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT
+        self.joint_mlp_layers = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_MLP_LAYERS
+
+        self.max_query_len = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
+        self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS
+        self.coord_dim = 8
+        self.joint_inp_dim = self.coord_dim + self.joint_embedding_size
+        self.joint_out_dim = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_OUT_SIZE
+
+        # mha params
+        self.n_head = 8
+        self.embed_dim = 2048
+        self.t2i_hidden_dim = 1024  # 256 * 4
+        self.i2t_hidden_dim = 3072  # 768 * 4
+
+        if self.lang_model in ["bert-base-uncased", "roberta-base", "clip"]:
+            self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
+        else:
+            self.lang_dim = 1024
+
+    def forward(self, x):
+        visual_features = x["visual"]
+        language_dict_features = x["lang"]
+
+        batch_size = visual_features[0].shape[0]
+        device = visual_features[0].device
+
+        fused_visual_features = None
+        fused_language_dict_features = None
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-S":
+            language_feature = language_dict_features['hidden']
+            mask = language_dict_features['masks']
+            # text -> image
+            if self.use_checkpoint:
+                q0, q1, q2, q3, q4 = checkpoint.checkpoint(
+                    self.t2i_attn,
+                    visual_features[0], visual_features[1],
+                    visual_features[2], visual_features[3],
+                    visual_features[4],
+                    language_feature, language_feature,
+                    mask,
+                    self.dummy_tensor
+                )
+            else:
+                q0, q1, q2, q3, q4 = self.t2i_attn(
+                    visual_features[0], visual_features[1],
+                    visual_features[2], visual_features[3],
+                    visual_features[4],
+                    language_feature, language_feature,
+                    attention_mask=mask
+                )
+
+            fused_visual_features = [q0, q1, q2, q3, q4]
+            fused_language_dict_features = language_dict_features
+
+        elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-B":
+            if self.use_checkpoint:
+                q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = checkpoint.checkpoint(self.b_attn,
+                    visual_features[0], visual_features[1],
+                    visual_features[2], visual_features[3],
+                    visual_features[4],
+                    language_dict_features['hidden'],
+                    language_dict_features['masks'],
+                    self.dummy_tensor
+                )
+            else:
+                q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = self.b_attn(
+                    visual_features[0], visual_features[1],
+                    visual_features[2], visual_features[3],
+                    visual_features[4],
+                    language_dict_features['hidden'],
+                    language_dict_features['masks'],
+                    self.dummy_tensor
+                )
+
+            fused_visual_features = [q0, q1, q2, q3, q4]
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL and self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT:
+                language_features = self.shrink_lang(torch.cat([l0, l1, l2, l3, l4], dim = -1))
+            else:
+                language_features = l0
+
+            language_dict_features['hidden'] = language_features
+            fused_language_dict_features = language_dict_features
+
+        elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "SCAN":
+            # text -> image
+            language_feature = language_dict_features['aggregate']
+            language_feature = self.mapping_lang(language_feature)
+            visu_feat = []
+            for ii, feat in enumerate(visual_features):
+                attn_feat = func_attention(feat, language_feature, smooth=1, raw_feature_norm="softmax")
+                visu_feat.append(attn_feat)
+
+            fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)]
+            fused_language_dict_features = language_dict_features
+
+        elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "FILM":
+            # text -> image
+            # relative position embedding
+            coord_feats = [_make_coord(batch_size, x.shape[2], x.shape[3]) for x in visual_features]
+            # I only use a global representation of language
+            # you can also use more complex modeling using word-level representations
+            # Usage: lang_feat = lang_feat['words'] shape [seq_len, dim]
+            language_feature = language_dict_features['aggregate']
+            language_feature = self.mapping_lang(language_feature)
+
+            # attention mechanism for fusion
+            gamma = [F.tanh(gamma(language_feature)) for gamma in self.gamma]
+            beta = [F.tanh(beta(language_feature)) for beta in self.beta]
+
+            visu_feat = []
+            for ii, feat in enumerate(visual_features):
+                coord_feat = coord_feats[ii].to(device)
+                feat = torch.cat([feat, coord_feat], dim=1)
+                b = beta[ii].view(batch_size, -1, 1, 1).expand_as(feat)
+                g = gamma[ii].view(batch_size, -1, 1, 1).expand_as(feat)
+                feat = F.relu(g * feat + b)
+                visu_feat.append(feat)
+
+            fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)]
+            fused_language_dict_features = language_dict_features
+
+        else:
+            fused_visual_features = visual_features
+            fused_language_dict_features = language_dict_features
+
+        features_dict = {"visual": fused_visual_features,
+                         "lang": fused_language_dict_features}
+
+        return features_dict
+
+
+class VLDyHead(torch.nn.Module):
+    def __init__(self, cfg):
+        super(VLDyHead, self).__init__()
+        self.cfg = cfg
+        # bert_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)
+        if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased":
+            lang_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)
+        elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
+            lang_cfg = cfg
+        else:
+            lang_cfg = None
+            raise NotImplementedError
+
+        num_classes = cfg.MODEL.DYHEAD.NUM_CLASSES - 1
+        num_tokens = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
+        num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
+        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
+        channels = cfg.MODEL.DYHEAD.CHANNELS
+
+        if cfg.MODEL.DYHEAD.USE_GN:
+            bn_type = ['gn', cfg.MODEL.GROUP_NORM.NUM_GROUPS]
+        elif cfg.MODEL.DYHEAD.USE_NSYNCBN:
+            bn_type = 'nsbn'
+        elif cfg.MODEL.DYHEAD.USE_SYNCBN:
+            bn_type = 'sbn'
+        else:
+            bn_type = None
+
+        use_dyrelu = cfg.MODEL.DYHEAD.USE_DYRELU
+        use_dyfuse = cfg.MODEL.DYHEAD.USE_DYFUSE
+        use_deform = cfg.MODEL.DYHEAD.USE_DFCONV
+
+        if cfg.MODEL.DYHEAD.CONV_FUNC:
+            conv_func = lambda i, o, s: eval(cfg.MODEL.DYHEAD.CONV_FUNC)(i, o, s, bn_type=bn_type)
+        else:
+            conv_func = lambda i, o, s: Conv3x3Norm(i, o, s, deformable=use_deform, bn_type=bn_type)
+
+        dyhead_tower = []
+        for i in range(cfg.MODEL.DYHEAD.NUM_CONVS):
+            if cfg.MODEL.DYHEAD.FUSE_CONFIG.EARLY_FUSE_ON:
+                # cross-modality fusion
+                dyhead_tower.append(
+                    VLFuse(cfg)
+                )
+                # self language path
+                if i < cfg.MODEL.DYHEAD.NUM_CONVS - 1 or cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT:
+                    # dyhead_tower.append(
+                    #     BertEncoderLayer(
+                    #     bert_cfg,
+                    #     clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW,
+                    #     clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW)
+                    # )
+                    if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased":
+                        dyhead_tower.append(
+                            BertEncoderLayer(
+                                lang_cfg,
+                                clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW,
+                                clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW)
+                        )
+                    elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
+                        dyhead_tower.append(
+                            CLIPTransformerLayer(lang_cfg)
+                        )
+                    else:
+                        raise NotImplementedError
+
+                else:
+                    dyhead_tower.append(
+                        DummyLayer()
+                    )
+
+            # self vision path
+            dyhead_tower.append(
+                DyConv(
+                    in_channels if i == 0 else channels,
+                    channels,
+                    conv_func=conv_func,
+                    use_dyrelu=(use_dyrelu and in_channels == channels) if i == 0 else use_dyrelu,
+                    use_dyfuse=(use_dyfuse and in_channels == channels) if i == 0 else use_dyfuse,
+                    use_deform=(use_deform and in_channels == channels) if i == 0 else use_deform,
+                )
+            )
+
+        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))
+
+        self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1)
+        self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=1)
+        self.centerness = nn.Conv2d(channels, num_anchors * 1, kernel_size=1)
+
+        # initialize the bias for focal loss
+        prior_prob = cfg.MODEL.DYHEAD.PRIOR_PROB
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+
+        log_scale = self.cfg.MODEL.DYHEAD.LOG_SCALE
+
+        # soft token head
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
+            self.token_logits = nn.Conv2d(channels, num_anchors * num_tokens, kernel_size=1)
+            # ABLATION
+            # self.token_logits = nn.Conv2d(channels, num_anchors * num_tokens, kernel_size=1, bias=False)
+            # self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
+            # self.bias0 = nn.Parameter(torch.Tensor([bias_value]), requires_grad=True)
+
+        # contrastive alignment head
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+            assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS == False
+            contrastive_hdim = cfg.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_HIDDEN_DIM
+            self.contrastive_align_projection_image = nn.Conv2d(channels, num_anchors * contrastive_hdim, kernel_size=1)
+            self.contrastive_align_projection_text = nn.Linear(channels, contrastive_hdim, bias=True)
+            self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
+
+        # dot product soft token head
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS == False
+            self.dot_product_projection_image = nn.Identity()
+            self.dot_product_projection_text = nn.Linear(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM,
+                                                         num_anchors * channels, bias=True)
+            self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
+            # DEBUG
+            # self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
+            self.bias_lang = nn.Parameter(torch.zeros(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM), requires_grad=True)
+            self.bias0 = nn.Parameter(torch.Tensor([bias_value]), requires_grad=True)
+
+        # initialization
+        for modules in [self.cls_logits, self.bbox_pred,
+                        self.centerness]:
+            for l in modules.modules():
+                if isinstance(l, nn.Conv2d):
+                    torch.nn.init.normal_(l.weight, std=0.01)
+                    torch.nn.init.constant_(l.bias, 0)
+
+        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
+
+        torch.nn.init.constant_(self.cls_logits.bias, bias_value)
+
+        # if use soft token loss
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
+            for modules in [self.token_logits]:
+                for l in modules.modules():
+                    if isinstance(l, nn.Conv2d):
+                        torch.nn.init.normal_(l.weight, std=0.01)
+                        torch.nn.init.constant_(l.bias, 0)
+
+            torch.nn.init.constant_(self.token_logits.bias, bias_value)
+            # print(torch.norm(self.token_logits.weight))
+
+        # if use contrastive loss
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+            for modules in [self.contrastive_align_projection_image]:
+                for l in modules.modules():
+                    if isinstance(l, nn.Conv2d):
+                        torch.nn.init.normal_(l.weight, std=0.01)
+                        torch.nn.init.constant_(l.bias, 0)
+
+        # if use dot product token loss
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            for modules in [self.dot_product_projection_image]:
+                for l in modules.modules():
+                    if isinstance(l, nn.Conv2d):
+                        torch.nn.init.normal_(l.weight, std=0.01)
+                        torch.nn.init.constant_(l.bias, bias_value)
+        
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
+            if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
+                lang_cfg = BertConfig.from_pretrained("bert-base-uncased")
+                lang_cfg.hidden_size = cfg.MODEL.CLIP.WIDTH
+                lang_cfg.vocab_size = cfg.MODEL.CLIP.VOCAB_SIZE
+            self.mlm_head = BertLMPredictionHead(
+                lang_cfg
+            ) #nn.Linear(hidden_size, config.vocab_size, bias=False)
+
+    def forward(self, x, language_dict_features=None, embedding=None, swint_feature_c4=None):
+        logits = []
+        bbox_reg = []
+        centerness = []
+
+        feat_inputs = {"visual": x,
+                       "lang": language_dict_features}
+
+        dyhead_tower = self.dyhead_tower(feat_inputs)
+
+        # soft token
+        t_logits = None
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
+            t_logits = []
+        
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT:
+            embedding = dyhead_tower["lang"]["hidden"]
+        
+        # MLM loss
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
+            mlm_logits = self.mlm_head(embedding)
+        else:
+            mlm_logits = None
+
+        # contrastive
+        contrastive_logits = None
+        proj_tokens = None
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+            contrastive_logits = []
+            # follow MDETR's way
+            proj_tokens = F.normalize(
+                self.contrastive_align_projection_text(embedding), p=2, dim=-1
+            )
+
+        # dot product soft token
+        dot_product_logits = None
+        dot_product_proj_tokens = None
+        dot_product_proj_tokens_bias = None
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            dot_product_logits = []
+            # norm
+            embedding = F.normalize(embedding, p=2, dim=-1)
+            dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0)
+            # w/o norm
+            # dot_product_proj_tokens = self.dot_product_projection_text(embedding / 28.0)
+
+            dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang) + self.bias0
+
+        # shallow contrastive (original feature from image & text encoder)
+        shallow_img_emb_feats = None
+        shallow_text_emb = None
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS \
+                or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
+            shallow_img_emb_feats = []
+            shallow_text_emb = embedding
+
+        # print([v.shape for v in x])
+        # shallow contrastive: use the feature from swint backbone
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
+            for b, feature in enumerate(swint_feature_c4):
+                # BF, CF, HF, WF = feat.shape
+                # shallow_img_emb = permute_and_flatten(feat, BF, -1, CF, HF, WF)
+                shallow_img_emb_feats.append(feature)
+
+        fused_visual_features = None
+        if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
+            fused_visual_features = []
+
+        # use the feature from FPN
+        for l, feature in enumerate(x):
+            logits.append(self.cls_logits(dyhead_tower["visual"][l]))
+
+            bbox_pred = self.scales[l](self.bbox_pred(dyhead_tower["visual"][l]))
+            bbox_reg.append(bbox_pred)
+
+            centerness.append(self.centerness(dyhead_tower["visual"][l]))
+
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
+                t_logits.append(self.token_logits(dyhead_tower["visual"][l]))
+
+                # ABLATION
+                # b = self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                # x = dyhead_tower["visual"][l]
+                # B, C, H, W = x.shape
+                # bias = b.repeat(B, 1, H, W)
+                # t_logits.append(self.token_logits(dyhead_tower["visual"][l] + bias) + self.bias0)
+
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+                x = dyhead_tower["visual"][l]
+                B, _, H, W = x.shape
+                C = proj_tokens.shape[2]
+                proj_queries = self.contrastive_align_projection_image(dyhead_tower["visual"][l])
+                proj_queries = permute_and_flatten(proj_queries, B, -1, C, H, W)
+                normalized_img_emb = F.normalize(proj_queries, p=2, dim=-1)
+                normalized_text_emb = proj_tokens
+                contrastive_logit = (
+                        torch.matmul(normalized_img_emb, normalized_text_emb.transpose(-1, -2)) / self.log_scale.exp())
+                contrastive_logits.append(contrastive_logit)
+
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+                x = dyhead_tower["visual"][l]
+                if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
+                    fused_visual_features.append(x)
+                B, C, H, W = x.shape
+
+                # add bias (language)
+                dot_product_proj_queries = self.dot_product_projection_image(x)
+                dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W)
+
+                A = dot_product_proj_queries.shape[1]
+                bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(1, A, 1)
+
+                dot_product_logit = (torch.matmul(dot_product_proj_queries, dot_product_proj_tokens.transpose(-1, -2)) / self.log_scale.exp()) + bias
+                if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_DOT_PRODUCT:
+                    dot_product_logit = torch.clamp(dot_product_logit, max=50000)
+                    dot_product_logit = torch.clamp(dot_product_logit, min=-50000)
+                dot_product_logits.append(dot_product_logit)
+
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
+                feat = feature
+                BF, CF, HF, WF = feat.shape
+                shallow_img_emb = permute_and_flatten(feat, BF, -1, CF, HF, WF)
+                shallow_img_emb_feats.append(shallow_img_emb)
+
+        # no matter the feature is from backboone or from fpn, we use shallow_img_embs all the time
+        if shallow_img_emb_feats is not None and shallow_text_emb is not None:
+            # shallow_img_embs = torch.cat(shallow_img_embs, dim=1)
+            proj_tokens = shallow_text_emb
+        return logits, bbox_reg, centerness, t_logits, proj_tokens, contrastive_logits, dot_product_logits, mlm_logits, shallow_img_emb_feats, fused_visual_features
+
+
+class VLDyHeadModule(torch.nn.Module):
+
+    def __init__(self, cfg):
+        super(VLDyHeadModule, self).__init__()
+        self.cfg = cfg
+        self.head = VLDyHead(cfg)
+        box_coder = BoxCoder(cfg)
+        self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder)
+        self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True)
+        self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False)
+        self.anchor_generator = make_anchor_generator_complex(cfg)
+
+        self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
+        self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE
+        self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT
+        if self.lang_model in ["bert-base-uncased", "roberta-base", "clip"]:
+            self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
+        else:
+            self.lang_dim = 1024
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+            self.resizer = FeatureResizer(
+                input_feat_size=self.lang_dim,
+                output_feat_size=self.joint_embedding_size,
+                dropout=self.joint_embedding_dropout
+            )
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER:
+            self.tunable_linear = torch.nn.Linear(self.lang_dim, 1000, bias=False)
+            self.tunable_linear.weight.data.fill_(0.0)
+
+    def forward(self, images, features, targets=None,
+                language_dict_features=None,
+                positive_map=None,
+                captions=None,
+                swint_feature_c4=None
+                ):
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+            # resizer needed
+            embedding = language_dict_features['embedded']
+            embedding = self.resizer(embedding)
+        elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            # no resizer needed
+            embedding = language_dict_features['embedded']
+        else:
+            embedding = None
+
+        if "masks" in language_dict_features:
+            text_masks = language_dict_features["masks"]
+        else:
+            text_masks = None
+        
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER:
+            embedding = self.tunable_linear.weight[:embedding.size(1), :].unsqueeze(0) + embedding
+            language_dict_features['embedded'] = embedding
+            language_dict_features['hidden'] = self.tunable_linear.weight[:embedding.size(1), :].unsqueeze(0) + language_dict_features['hidden']
+
+        box_cls, box_regression, centerness, token_logits, \
+        proj_tokens, contrastive_logits, dot_product_logits, mlm_logits, shallow_img_emb_feats, fused_visual_features = self.head(features,
+                                                                        language_dict_features,
+                                                                        embedding,
+                                                                        swint_feature_c4
+                                                                        )
+        anchors = self.anchor_generator(images, features)
+
+        if self.training:
+            return self._forward_train(box_cls, box_regression, centerness, targets, anchors,
+                                       captions,
+                                       positive_map,
+                                       token_logits,
+                                       proj_tokens,
+                                       contrastive_logits,
+                                       dot_product_logits,
+                                       text_masks,
+                                       mlm_logits = mlm_logits,
+                                       mlm_labels = language_dict_features["mlm_labels"],
+                                       shallow_img_emb_feats=shallow_img_emb_feats,
+                                       fused_visual_features=fused_visual_features
+                                       )
+        else:
+            return self._forward_test(box_regression, centerness, anchors,
+                                      box_cls,
+                                      token_logits,
+                                      dot_product_logits,
+                                      positive_map,
+                                      fused_visual_features=fused_visual_features
+                                      )
+
+    def _forward_train(self, box_cls, box_regression, centerness, targets, anchors,
+                       captions=None,
+                       positive_map=None,
+                       token_logits=None,
+                       proj_tokens=None,
+                       contrastive_logits=None,
+                       dot_product_logits=None,
+                       text_masks=None,
+                       mlm_logits=None,
+                       mlm_labels=None,
+                       shallow_img_emb_feats=None,
+                       fused_visual_features=None
+                       ):
+
+        loss_box_cls, loss_box_reg, loss_centerness, loss_token, loss_contrastive_align, loss_dot_product_token, loss_shallow_contrastive = self.loss_evaluator(
+            box_cls, box_regression, centerness, targets, anchors,
+            captions,
+            positive_map,
+            token_logits,
+            proj_tokens,
+            contrastive_logits,
+            dot_product_logits,
+            text_masks,
+            shallow_img_emb_feats
+        )
+
+        losses = {
+            # "loss_cls": loss_box_cls,
+            "loss_reg": loss_box_reg,
+            "loss_centerness": loss_centerness
+        }
+
+        if mlm_labels is not None and mlm_logits is not None:
+            losses["mlm_loss"] = nn.CrossEntropyLoss(ignore_index = -100)(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1)) * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_COEF
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CLASSIFICATION_LOSS:
+            losses["loss_cls"] = loss_box_cls
+        else:
+            losses["loss_cls"] = 0.0 * loss_box_cls
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
+            losses["loss_token"] = loss_token * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_LOSS_WEIGHT
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
+            losses["loss_contrastive_align"] = loss_contrastive_align * \
+                                               self.cfg.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_ALIGN_LOSS_WEIGHT
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
+            losses["loss_dot_product_token"] = loss_dot_product_token * \
+                                               self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DOT_PRODUCT_TOKEN_LOSS_WEIGHT
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS or \
+                self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
+            losses["loss_shallow_contrastive"] = loss_shallow_contrastive * \
+                                                 self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_LOSS_WEIGHT
+
+        if self.cfg.MODEL.RPN_ONLY:
+            return None, losses, None
+        else:
+            # Let's just use one image per batch
+            assert (box_regression[0].shape[0]) == 1
+            positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=1)
+            boxes = self.box_selector_train(box_regression, centerness, anchors,
+                                        box_cls,
+                                        token_logits,
+                                        dot_product_logits,
+                                        positive_map=positive_map_label_to_token
+                                        )
+            train_boxes = []
+            for b, t in zip(boxes, targets):
+                tb = t.copy_with_fields(["labels"])
+                tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device))
+                train_boxes.append(cat_boxlist([b, tb]))
+            return train_boxes, losses, fused_visual_features
+
+    def _forward_test(self, box_regression, centerness, anchors,
+                      box_cls=None,
+                      token_logits=None,
+                      dot_product_logits=None,
+                      positive_map=None,
+                      fused_visual_features=None
+                      ):
+
+        boxes = self.box_selector_test(box_regression, centerness, anchors,
+                                       box_cls,
+                                       token_logits,
+                                       dot_product_logits,
+                                       positive_map,
+                                       )
+        return boxes, {}, fused_visual_features
diff --git a/maskrcnn_benchmark/modeling/utils.py b/maskrcnn_benchmark/modeling/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2834b105c8d171438a4534eb17fc0da65154d610
--- /dev/null
+++ b/maskrcnn_benchmark/modeling/utils.py
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""
+Miscellaneous utility functions
+"""
+
+import torch
+
+
+def cat(tensors, dim=0):
+    """
+    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
+    """
+    assert isinstance(tensors, (list, tuple))
+    if len(tensors) == 1:
+        return tensors[0]
+    return torch.cat(tensors, dim)
+
+
+def permute_and_flatten(layer, N, A, C, H, W):
+    layer = layer.view(N, -1, C, H, W)
+    layer = layer.permute(0, 3, 4, 1, 2)
+    layer = layer.reshape(N, -1, C)
+    return layer
+
+
+def concat_box_prediction_layers(box_regression, box_cls=None, token_logits=None):
+    box_regression_flattened = []
+    box_cls_flattened = []
+    token_logit_flattened = []
+
+    # for each feature level, permute the outputs to make them be in the
+    # same format as the labels. Note that the labels are computed for
+    # all feature levels concatenated, so we keep the same representation
+    # for the objectness and the box_regression
+    for box_cls_per_level, box_regression_per_level in zip(
+            box_cls, box_regression
+    ):
+        N, AxC, H, W = box_cls_per_level.shape
+        Ax4 = box_regression_per_level.shape[1]
+        A = Ax4 // 4
+        C = AxC // A
+        box_cls_per_level = permute_and_flatten(
+            box_cls_per_level, N, A, C, H, W
+        )
+        box_cls_flattened.append(box_cls_per_level)
+
+        box_regression_per_level = permute_and_flatten(
+            box_regression_per_level, N, A, 4, H, W
+        )
+        box_regression_flattened.append(box_regression_per_level)
+
+    if token_logits is not None:
+        for token_logit_per_level in token_logits:
+            N, AXT, H, W = token_logit_per_level.shape
+            T = AXT // A
+            token_logit_per_level = permute_and_flatten(
+                token_logit_per_level, N, A, T, H, W
+            )
+            token_logit_flattened.append(token_logit_per_level)
+
+    # concatenate on the first dimension (representing the feature levels), to
+    # take into account the way the labels were generated (with all feature maps
+    # being concatenated as well)
+    box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C)
+    box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)
+
+    token_logits_stacked = None
+    if token_logits is not None:
+        # stacked
+        token_logits_stacked = cat(token_logit_flattened, dim=1)
+
+    return box_regression, box_cls, token_logits_stacked
+
+
+def round_channels(channels, divisor=8):
+    rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor)
+    if float(rounded_channels) < 0.9 * channels:
+        rounded_channels += divisor
+    return rounded_channels
diff --git a/maskrcnn_benchmark/solver/__init__.py b/maskrcnn_benchmark/solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..927668ea6f35aedcff25f779e85a8b8c27a8c797
--- /dev/null
+++ b/maskrcnn_benchmark/solver/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from .build import make_optimizer
+from .build import make_lr_scheduler
+from .lr_scheduler import WarmupMultiStepLR
diff --git a/maskrcnn_benchmark/solver/build.py b/maskrcnn_benchmark/solver/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..4456f914f2349d3a86642871161e95e4cd26af7d
--- /dev/null
+++ b/maskrcnn_benchmark/solver/build.py
@@ -0,0 +1,116 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+import itertools
+
+from .lr_scheduler import WarmupMultiStepLR, WarmupCosineAnnealingLR, WarmupReduceLROnPlateau
+
+
+def make_optimizer(cfg, model):
+    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
+        # detectron2 doesn't have full model gradient clipping now
+        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+        enable = (
+                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
+                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
+                and clip_norm_val > 0.0
+        )
+
+        class FullModelGradientClippingOptimizer(optim):
+            def step(self, closure=None):
+                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
+                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
+                super().step(closure=closure)
+
+        return FullModelGradientClippingOptimizer if enable else optim
+
+    params = []
+    for key, value in model.named_parameters():
+        if not value.requires_grad:
+            continue
+        lr = cfg.SOLVER.BASE_LR
+        weight_decay = cfg.SOLVER.WEIGHT_DECAY
+
+        # different lr schedule
+        if "language_backbone" in key:
+            lr = cfg.SOLVER.LANG_LR
+
+        if "backbone.body" in key and "language_backbone.body" not in key:
+            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BACKBONE_BODY_LR_FACTOR
+
+        if "bias" in key:
+            lr *= cfg.SOLVER.BIAS_LR_FACTOR
+            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
+
+        if 'norm' in key or 'Norm' in key:
+            weight_decay *= cfg.SOLVER.WEIGHT_DECAY_NORM_FACTOR
+            print("Setting weight decay of {} to {}".format(key, weight_decay))
+
+        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
+
+    if cfg.SOLVER.OPTIMIZER == "SGD":
+        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(params, lr, momentum=cfg.SOLVER.MOMENTUM)
+    elif cfg.SOLVER.OPTIMIZER == "ADAMW":
+        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(params, lr)
+
+    return optimizer
+
+
+def make_lr_scheduler(cfg, optimizer):
+    if cfg.SOLVER.MULTI_MAX_EPOCH:
+        assert len(cfg.SOLVER.MULTI_MAX_EPOCH) == len(cfg.SOLVER.STEPS)
+        lr_scheduler = []
+
+        for stage_step, stage_max_epoch in zip(cfg.SOLVER.STEPS, cfg.SOLVER.MULTI_MAX_ITER):
+            milestones = []
+            for step in stage_step:
+                milestones.append(round(step * stage_max_epoch))
+            lr_scheduler.append(WarmupMultiStepLR(optimizer,
+                                                  milestones,
+                                                  cfg.SOLVER.GAMMA,
+                                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
+                                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS,
+                                                  warmup_method=cfg.SOLVER.WARMUP_METHOD, )
+                                )
+        return lr_scheduler
+
+    elif cfg.SOLVER.USE_COSINE:
+        max_iters = cfg.SOLVER.MAX_ITER
+        return WarmupCosineAnnealingLR(
+            optimizer,
+            max_iters,
+            cfg.SOLVER.GAMMA,
+            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
+            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
+            warmup_method=cfg.SOLVER.WARMUP_METHOD,
+            eta_min=cfg.SOLVER.MIN_LR
+        )
+
+    elif cfg.SOLVER.USE_AUTOSTEP:
+        max_iters = cfg.SOLVER.MAX_ITER
+        return WarmupReduceLROnPlateau(
+            optimizer,
+            max_iters,
+            cfg.SOLVER.GAMMA,
+            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
+            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
+            warmup_method=cfg.SOLVER.WARMUP_METHOD,
+            eta_min=cfg.SOLVER.MIN_LR,
+            patience=cfg.SOLVER.STEP_PATIENCE,
+            verbose=True
+        )
+
+    else:
+        milestones = []
+        for step in cfg.SOLVER.STEPS:
+            if step < 1:
+                milestones.append(round(step * cfg.SOLVER.MAX_ITER))
+            else:
+                milestones.append(step)
+        return WarmupMultiStepLR(
+            optimizer,
+            milestones,
+            cfg.SOLVER.GAMMA,
+            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
+            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
+            warmup_method=cfg.SOLVER.WARMUP_METHOD,
+        )
diff --git a/maskrcnn_benchmark/solver/lr_scheduler.py b/maskrcnn_benchmark/solver/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a06a52d57a1da5433c06555a551753dfe38a0fa8
--- /dev/null
+++ b/maskrcnn_benchmark/solver/lr_scheduler.py
@@ -0,0 +1,164 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from bisect import bisect_right
+
+import math
+import torch
+
+
+# FIXME ideally this would be achieved with a CombinedLRScheduler,
+# separating MultiStepLR with WarmupLR
+# but the current LRScheduler design doesn't allow it
+class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(
+        self,
+        optimizer,
+        milestones,
+        gamma=0.1,
+        warmup_factor=1.0 / 3,
+        warmup_iters=500,
+        warmup_method="linear",
+        last_epoch=-1,
+    ):
+        if not list(milestones) == sorted(milestones):
+            raise ValueError(
+                "Milestones should be a list of" " increasing integers. Got {}",
+                milestones,
+            )
+
+        if warmup_method not in ("constant", "linear"):
+            raise ValueError(
+                "Only 'constant' or 'linear' warmup_method accepted"
+                "got {}".format(warmup_method)
+            )
+        self.milestones = milestones
+        self.gamma = gamma
+        self.warmup_factor = warmup_factor
+        self.warmup_iters = warmup_iters
+        self.warmup_method = warmup_method
+        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        warmup_factor = 1
+        if self.last_epoch < self.warmup_iters:
+            if self.warmup_method == "constant":
+                warmup_factor = self.warmup_factor
+            elif self.warmup_method == "linear":
+                alpha = float(self.last_epoch) / self.warmup_iters
+                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
+        return [
+            base_lr
+            * warmup_factor
+            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
+            for base_lr in self.base_lrs
+        ]
+
+
+class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(
+            self,
+            optimizer,
+            max_iters,
+            gamma=0.1,
+            warmup_factor=1.0 / 3,
+            warmup_iters=500,
+            warmup_method="linear",
+            eta_min = 0,
+            last_epoch=-1,
+    ):
+
+        if warmup_method not in ("constant", "linear"):
+            raise ValueError(
+                "Only 'constant' or 'linear' warmup_method accepted"
+                "got {}".format(warmup_method)
+            )
+        self.max_iters = max_iters
+        self.gamma = gamma
+        self.warmup_factor = warmup_factor
+        self.warmup_iters = warmup_iters
+        self.warmup_method = warmup_method
+        self.eta_min = eta_min
+        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        warmup_factor = 1
+
+        if self.last_epoch < self.warmup_iters:
+            if self.warmup_method == "constant":
+                warmup_factor = self.warmup_factor
+            elif self.warmup_method == "linear":
+                alpha = float(self.last_epoch) / self.warmup_iters
+                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
+            return [
+                base_lr
+                * warmup_factor
+                for base_lr in self.base_lrs
+            ]
+        else:
+            return [
+                self.eta_min
+                + (base_lr - self.eta_min)
+                * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_iters) / self.max_iters)) / 2
+                for base_lr in self.base_lrs
+            ]
+
+class WarmupReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
+    def __init__(
+            self,
+            optimizer,
+            max_iters,
+            gamma=0.1,
+            warmup_factor=1.0 / 3,
+            warmup_iters=500,
+            warmup_method="linear",
+            eta_min = 0,
+            last_epoch=-1,
+            patience = 5,
+            verbose = False,
+    ):    
+
+        if warmup_method not in ("constant", "linear"):
+            raise ValueError(
+                "Only 'constant' or 'linear' warmup_method accepted"
+                "got {}".format(warmup_method)
+            )
+        self.warmup_factor = warmup_factor
+        self.warmup_iters = warmup_iters
+        self.warmup_method = warmup_method
+        self.eta_min = eta_min
+
+        if last_epoch == -1:
+            for group in optimizer.param_groups:
+                group.setdefault('initial_lr', group['lr'])
+        else:
+            for i, group in enumerate(optimizer.param_groups):
+                if 'initial_lr' not in group:
+                    raise KeyError("param 'initial_lr' is not specified "
+                                   "in param_groups[{}] when resuming an optimizer".format(i))
+        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+        super(WarmupReduceLROnPlateau, self).__init__(optimizer, factor=gamma, patience=patience, mode='max', min_lr=eta_min, verbose = verbose)
+
+    def step(self, metrics=None):
+        warmup_factor = 1
+
+        if self.last_epoch < self.warmup_iters:
+            if self.warmup_method == "constant":
+                warmup_factor = self.warmup_factor
+            elif self.warmup_method == "linear":
+                alpha = float(self.last_epoch) / self.warmup_iters
+                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
+            
+            if self.last_epoch >= self.warmup_iters-1:
+                warmup_factor = 1.0
+                
+            warmup_lrs = [
+                base_lr
+                * warmup_factor
+                for base_lr in self.base_lrs
+            ]
+
+            for param_group, lr in zip(self.optimizer.param_groups, warmup_lrs):
+                param_group['lr'] = lr
+            
+            self.last_epoch += 1
+        elif metrics:
+            super().step(metrics)
\ No newline at end of file
diff --git a/maskrcnn_benchmark/structures/__init__.py b/maskrcnn_benchmark/structures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/maskrcnn_benchmark/structures/bounding_box.py b/maskrcnn_benchmark/structures/bounding_box.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b04683086ffad2345aed97b08d0c11ac385ba85
--- /dev/null
+++ b/maskrcnn_benchmark/structures/bounding_box.py
@@ -0,0 +1,321 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+# transpose
+FLIP_LEFT_RIGHT = 0
+FLIP_TOP_BOTTOM = 1
+
+
+class BoxList(object):
+    """
+    This class represents a set of bounding boxes.
+    The bounding boxes are represented as a Nx4 Tensor.
+    In order to uniquely determine the bounding boxes with respect
+    to an image, we also store the corresponding image dimensions.
+    They can contain extra information that is specific to each bounding box, such as
+    labels.
+    """
+
+    def __init__(self, bbox, image_size, mode="xyxy"):
+        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
+        # only do as_tensor if isn't a "no-op", because it hurts JIT tracing
+        if (not isinstance(bbox, torch.Tensor)
+                or bbox.dtype != torch.float32 or bbox.device != device):
+            bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
+        if bbox.ndimension() != 2:
+            raise ValueError(
+                "bbox should have 2 dimensions, got {}".format(bbox.ndimension())
+            )
+        if bbox.size(-1) != 4:
+            raise ValueError(
+                "last dimenion of bbox should have a "
+                "size of 4, got {}".format(bbox.size(-1))
+            )
+        if mode not in ("xyxy", "xywh"):
+            raise ValueError("mode should be 'xyxy' or 'xywh'")
+
+        self.bbox = bbox
+        self.size = image_size  # (image_width, image_height)
+        self.mode = mode
+        self.extra_fields = {}
+
+    # note: _jit_wrap/_jit_unwrap only work if the keys and the sizes don't change in between
+    def _jit_unwrap(self):
+        return (self.bbox,) + tuple(f for f in (self.get_field(field)
+                                    for field in sorted(self.fields()))
+                                    if isinstance(f, torch.Tensor))
+
+    def _jit_wrap(self, input_stream):
+        self.bbox = input_stream[0]
+        num_consumed = 1
+        for f in sorted(self.fields()):
+            if isinstance(self.extra_fields[f], torch.Tensor):
+                self.extra_fields[f] = input_stream[num_consumed]
+                num_consumed += 1
+        return self, input_stream[num_consumed:]
+
+    def add_field(self, field, field_data):
+        self.extra_fields[field] = field_data
+
+    def get_field(self, field):
+        return self.extra_fields[field]
+
+    def has_field(self, field):
+        return field in self.extra_fields
+
+    def fields(self):
+        return list(self.extra_fields.keys())
+
+    def _copy_extra_fields(self, bbox):
+        for k, v in bbox.extra_fields.items():
+            self.extra_fields[k] = v
+
+    def convert(self, mode):
+        if mode not in ("xyxy", "xywh"):
+            raise ValueError("mode should be 'xyxy' or 'xywh'")
+        if mode == self.mode:
+            return self
+        # we only have two modes, so don't need to check
+        # self.mode
+        xmin, ymin, xmax, ymax = self._split_into_xyxy()
+        if mode == "xyxy":
+            bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
+            bbox = BoxList(bbox, self.size, mode=mode)
+        else:
+            TO_REMOVE = 1
+            # NOTE: explicitly specify dim to avoid tracing error in GPU
+            bbox = torch.cat(
+                (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1
+            )
+            bbox = BoxList(bbox, self.size, mode=mode)
+        bbox._copy_extra_fields(self)
+        return bbox
+
+    def _split_into_xyxy(self):
+        if self.mode == "xyxy":
+            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
+            return xmin, ymin, xmax, ymax
+        elif self.mode == "xywh":
+            TO_REMOVE = 1
+            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
+            return (
+                xmin,
+                ymin,
+                xmin + (w - TO_REMOVE).clamp(min=0),
+                ymin + (h - TO_REMOVE).clamp(min=0),
+            )
+        else:
+            raise RuntimeError("Should not be here")
+
+    def resize(self, size, *args, **kwargs):
+        """
+        Returns a resized copy of this bounding box
+
+        :param size: The requested size in pixels, as a 2-tuple:
+            (width, height).
+        """
+
+        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
+        if ratios[0] == ratios[1]:
+            ratio = ratios[0]
+            scaled_box = self.bbox * ratio
+            bbox = BoxList(scaled_box, size, mode=self.mode)
+            # bbox._copy_extra_fields(self)
+            for k, v in self.extra_fields.items():
+                if not isinstance(v, torch.Tensor):
+                    v = v.resize(size, *args, **kwargs)
+                bbox.add_field(k, v)
+            return bbox
+
+        ratio_width, ratio_height = ratios
+        xmin, ymin, xmax, ymax = self._split_into_xyxy()
+        scaled_xmin = xmin * ratio_width
+        scaled_xmax = xmax * ratio_width
+        scaled_ymin = ymin * ratio_height
+        scaled_ymax = ymax * ratio_height
+        scaled_box = torch.cat(
+            (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
+        )
+        bbox = BoxList(scaled_box, size, mode="xyxy")
+        # bbox._copy_extra_fields(self)
+        for k, v in self.extra_fields.items():
+            if not isinstance(v, torch.Tensor):
+                v = v.resize(size, *args, **kwargs)
+            bbox.add_field(k, v)
+
+        return bbox.convert(self.mode)
+
+    def transpose(self, method):
+        """
+        Transpose bounding box (flip or rotate in 90 degree steps)
+        :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
+          :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
+          :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
+          :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
+        """
+        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
+            raise NotImplementedError(
+                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
+            )
+
+        image_width, image_height = self.size
+        xmin, ymin, xmax, ymax = self._split_into_xyxy()
+        if method == FLIP_LEFT_RIGHT:
+            TO_REMOVE = 1
+            transposed_xmin = image_width - xmax - TO_REMOVE
+            transposed_xmax = image_width - xmin - TO_REMOVE
+            transposed_ymin = ymin
+            transposed_ymax = ymax
+        elif method == FLIP_TOP_BOTTOM:
+            transposed_xmin = xmin
+            transposed_xmax = xmax
+            transposed_ymin = image_height - ymax
+            transposed_ymax = image_height - ymin
+
+        transposed_boxes = torch.cat(
+            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
+        )
+        bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
+        # bbox._copy_extra_fields(self)
+        for k, v in self.extra_fields.items():
+            if not isinstance(v, torch.Tensor):
+                v = v.transpose(method)
+            bbox.add_field(k, v)
+        return bbox.convert(self.mode)
+
+    def crop(self, box):
+        """
+        Cropss a rectangular region from this bounding box. The box is a
+        4-tuple defining the left, upper, right, and lower pixel
+        coordinate.
+        """
+        xmin, ymin, xmax, ymax = self._split_into_xyxy()
+        w, h = box[2] - box[0], box[3] - box[1]
+        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
+        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
+        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
+        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)
+
+        # TODO should I filter empty boxes here?
+        cropped_box = torch.cat(
+            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
+        )
+        bbox = BoxList(cropped_box, (w, h), mode="xyxy")
+        # bbox._copy_extra_fields(self)
+        for k, v in self.extra_fields.items():
+            if not isinstance(v, torch.Tensor):
+                v = v.crop(box)
+            bbox.add_field(k, v)
+        return bbox.convert(self.mode)
+
+    # Tensor-like methods
+
+    def to(self, device):
+        bbox = BoxList(self.bbox.to(device), self.size, self.mode)
+        for k, v in self.extra_fields.items():
+            if hasattr(v, "to"):
+                v = v.to(device)
+            bbox.add_field(k, v)
+        return bbox
+
+    def __getitem__(self, item):
+        bbox = BoxList(self.bbox[item], self.size, self.mode)
+        for k, v in self.extra_fields.items():
+            bbox.add_field(k, v[item])
+        return bbox
+
+    def __len__(self):
+        return self.bbox.shape[0]
+
+    def clip_to_image(self, remove_empty=True):
+        TO_REMOVE = 1
+        x1s = self.bbox[:, 0].clamp(min=0, max=self.size[0] - TO_REMOVE)
+        y1s = self.bbox[:, 1].clamp(min=0, max=self.size[1] - TO_REMOVE)
+        x2s = self.bbox[:, 2].clamp(min=0, max=self.size[0] - TO_REMOVE)
+        y2s = self.bbox[:, 3].clamp(min=0, max=self.size[1] - TO_REMOVE)
+        self.bbox = torch.stack((x1s, y1s, x2s, y2s), dim=-1)
+        if remove_empty:
+            box = self.bbox
+            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
+            return self[keep]
+        return self
+
+    def area(self):
+        if self.mode == 'xyxy':
+            TO_REMOVE = 1
+            box = self.bbox
+            area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
+        elif self.mode == 'xywh':
+            box = self.bbox
+            area = box[:, 2] * box[:, 3]
+        else:
+            raise RuntimeError("Should not be here")
+            
+        return area
+
+    def copy_with_fields(self, fields):
+        bbox = BoxList(self.bbox, self.size, self.mode)
+        if not isinstance(fields, (list, tuple)):
+            fields = [fields]
+        for field in fields:
+            bbox.add_field(field, self.get_field(field))
+        return bbox
+
+    def __repr__(self):
+        s = self.__class__.__name__ + "("
+        s += "num_boxes={}, ".format(len(self))
+        s += "image_width={}, ".format(self.size[0])
+        s += "image_height={}, ".format(self.size[1])
+        s += "mode={})".format(self.mode)
+        return s
+    
+    @staticmethod
+    def concate_box_list(list_of_boxes):
+        boxes = torch.cat([i.bbox for i in list_of_boxes], dim = 0)
+        extra_fields_keys = list(list_of_boxes[0].extra_fields.keys())
+        extra_fields = {}
+        for key in extra_fields_keys:
+            extra_fields[key] = torch.cat([i.extra_fields[key] for i in list_of_boxes], dim = 0)
+
+        final = list_of_boxes[0].copy_with_fields(extra_fields_keys)
+
+        final.bbox = boxes
+        final.extra_fields = extra_fields
+        return final
+
+@torch.jit.unused
+def _onnx_clip_boxes_to_image(boxes, size):
+    # type: (Tensor, Tuple[int, int])
+    """
+    Clip boxes so that they lie inside an image of size `size`.
+    Clip's min max are traced as constants. Use torch.min/max to WAR this issue
+    Arguments:
+        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
+        size (Tuple[height, width]): size of the image
+    Returns:
+        clipped_boxes (Tensor[N, 4])
+    """
+    TO_REMOVE = 1
+    device = boxes.device
+    dim = boxes.dim()
+    boxes_x = boxes[..., 0::2]
+    boxes_y = boxes[..., 1::2]
+
+    boxes_x = torch.max(boxes_x, torch.tensor(0., dtype=torch.float).to(device))
+    boxes_x = torch.min(boxes_x, torch.tensor(size[1] - TO_REMOVE, dtype=torch.float).to(device))
+    boxes_y = torch.max(boxes_y, torch.tensor(0., dtype=torch.float).to(device))
+    boxes_y = torch.min(boxes_y, torch.tensor(size[0] - TO_REMOVE, dtype=torch.float).to(device))
+
+    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
+    return clipped_boxes.reshape(boxes.shape)
+
+
+if __name__ == "__main__":
+    bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
+    s_bbox = bbox.resize((5, 5))
+    print(s_bbox)
+    print(s_bbox.bbox)
+
+    t_bbox = bbox.transpose(0)
+    print(t_bbox)
+    print(t_bbox.bbox)
diff --git a/maskrcnn_benchmark/structures/boxlist_ops.py b/maskrcnn_benchmark/structures/boxlist_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..85eb081ca64bf4464d6e523759b82882498bf4da
--- /dev/null
+++ b/maskrcnn_benchmark/structures/boxlist_ops.py
@@ -0,0 +1,184 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+from .bounding_box import BoxList
+
+from maskrcnn_benchmark.layers import nms as _box_nms
+from maskrcnn_benchmark.layers import ml_nms as _box_ml_nms
+
+
+def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="score"):
+    """
+    Performs non-maximum suppression on a boxlist, with scores specified
+    in a boxlist field via score_field.
+
+    Arguments:
+        boxlist(BoxList)
+        nms_thresh (float)
+        max_proposals (int): if > 0, then only the top max_proposals are kept
+            after non-maxium suppression
+        score_field (str)
+    """
+    if nms_thresh <= 0:
+        return boxlist
+    mode = boxlist.mode
+    boxlist = boxlist.convert("xyxy")
+    boxes = boxlist.bbox
+    score = boxlist.get_field(score_field)
+    keep = _box_nms(boxes, score, nms_thresh)
+    if max_proposals > 0:
+        keep = keep[: max_proposals]
+    boxlist = boxlist[keep]
+    return boxlist.convert(mode)
+
+
+def boxlist_ml_nms(boxlist, nms_thresh, max_proposals=-1,
+                   score_field="scores", label_field="labels"):
+    """
+    Performs non-maximum suppression on a boxlist, with scores specified
+    in a boxlist field via score_field.
+
+    Arguments:
+        boxlist(BoxList)
+        nms_thresh (float)
+        max_proposals (int): if > 0, then only the top max_proposals are kept
+            after non-maximum suppression
+        score_field (str)
+    """
+    if nms_thresh <= 0:
+        return boxlist
+    mode = boxlist.mode
+    boxlist = boxlist.convert("xyxy")
+    boxes = boxlist.bbox
+    scores = boxlist.get_field(score_field)
+    labels = boxlist.get_field(label_field)
+
+    if boxes.device==torch.device("cpu"):
+        keep = []
+        unique_labels = torch.unique(labels)
+        print(unique_labels)
+        for j in unique_labels:
+            inds = (labels == j).nonzero().view(-1)
+
+            scores_j = scores[inds]
+            boxes_j = boxes[inds, :].view(-1, 4)
+            keep_j = _box_nms(boxes_j, scores_j, nms_thresh)
+
+            keep += keep_j
+    else:
+        keep = _box_ml_nms(boxes, scores, labels.float(), nms_thresh)
+        
+    if max_proposals > 0:
+        keep = keep[: max_proposals]
+    boxlist = boxlist[keep]
+
+    return boxlist.convert(mode)
+
+
+def remove_small_boxes(boxlist, min_size):
+    """
+    Only keep boxes with both sides >= min_size
+
+    Arguments:
+        boxlist (Boxlist)
+        min_size (int)
+    """
+    # WORK AROUND: work around unbind using split + squeeze.
+    xywh_boxes = boxlist.convert("xywh").bbox
+    _, _, ws, hs = xywh_boxes.split(1, dim=1)
+    ws = ws.squeeze(1)
+    hs = hs.squeeze(1)
+    keep = ((ws >= min_size) & (hs >= min_size)).nonzero().squeeze(1)
+    return boxlist[keep]
+
+
+# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
+# with slight modifications
+def boxlist_iou(boxlist1, boxlist2):
+    """Compute the intersection over union of two set of boxes.
+    The box order must be (xmin, ymin, xmax, ymax).
+
+    Arguments:
+      box1: (BoxList) bounding boxes, sized [N,4].
+      box2: (BoxList) bounding boxes, sized [M,4].
+
+    Returns:
+      (tensor) iou, sized [N,M].
+
+    Reference:
+      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
+    """
+    if boxlist1.size != boxlist2.size:
+        raise RuntimeError(
+                "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2))
+
+    N = len(boxlist1)
+    M = len(boxlist2)
+
+    area1 = boxlist1.area()
+    area2 = boxlist2.area()
+
+    box1, box2 = boxlist1.bbox, boxlist2.bbox
+
+    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
+    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]
+
+    TO_REMOVE = 1
+
+    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    iou = inter / (area1[:, None] + area2 - inter)
+    return iou
+
+
+# TODO redundant, remove
+def _cat(tensors, dim=0):
+    """
+    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
+    """
+    assert isinstance(tensors, (list, tuple))
+    if len(tensors) == 1:
+        return tensors[0]
+    if isinstance(tensors[0], torch.Tensor):
+        return torch.cat(tensors, dim)
+    else:
+        return cat_boxlist(tensors)
+
+def cat_boxlist(bboxes):
+    """
+    Concatenates a list of BoxList (having the same image size) into a
+    single BoxList
+
+    Arguments:
+        bboxes (list[BoxList])
+    """
+    assert isinstance(bboxes, (list, tuple))
+    assert all(isinstance(bbox, BoxList) for bbox in bboxes)
+
+    size = bboxes[0].size
+    assert all(bbox.size == size for bbox in bboxes)
+
+    mode = bboxes[0].mode
+    assert all(bbox.mode == mode for bbox in bboxes)
+
+    fields = set(bboxes[0].fields())
+    assert all(set(bbox.fields()) == fields for bbox in bboxes)
+
+    cat_boxes = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size, mode)
+
+    for field in fields:
+        data = _cat([bbox.get_field(field) for bbox in bboxes], dim=0)
+        cat_boxes.add_field(field, data)
+
+    return cat_boxes
+
+
+def getUnionBBox(aBB, bBB, margin = 10):
+    assert aBB.size==bBB.size
+    assert aBB.mode==bBB.mode
+    ih, iw = aBB.size
+    union_boxes = torch.cat([(torch.min(aBB.bbox[:,[0,1]], bBB.bbox[:,[0,1]]) - margin).clamp(min=0), \
+        (torch.max(aBB.bbox[:,[2]], bBB.bbox[:,[2]]) + margin).clamp(max=iw), \
+        (torch.max(aBB.bbox[:,[3]], bBB.bbox[:,[3]]) + margin).clamp(max=ih)], dim=1)
+    return BoxList(union_boxes, aBB.size, mode=aBB.mode)
diff --git a/maskrcnn_benchmark/structures/image_list.py b/maskrcnn_benchmark/structures/image_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..e24df46e95ba39476fdce9f748c0e0f4fb94be98
--- /dev/null
+++ b/maskrcnn_benchmark/structures/image_list.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from __future__ import division
+
+import torch
+
+
+class ImageList(object):
+    """
+    Structure that holds a list of images (of possibly
+    varying sizes) as a single tensor.
+    This works by padding the images to the same size,
+    and storing in a field the original sizes of each image
+    """
+
+    def __init__(self, tensors, image_sizes):
+        """
+        Arguments:
+            tensors (tensor)
+            image_sizes (list[tuple[int, int]])
+        """
+        self.tensors = tensors
+        self.image_sizes = image_sizes
+
+    def to(self, *args, **kwargs):
+        cast_tensor = self.tensors.to(*args, **kwargs)
+        return ImageList(cast_tensor, self.image_sizes)
+
+
+def to_image_list(tensors, size_divisible=0):
+    """
+    tensors can be an ImageList, a torch.Tensor or
+    an iterable of Tensors. It can't be a numpy array.
+    When tensors is an iterable of Tensors, it pads
+    the Tensors with zeros so that they have the same
+    shape
+    """
+    if isinstance(tensors, torch.Tensor) and size_divisible > 0:
+        tensors = [tensors]
+
+    if isinstance(tensors, ImageList):
+        return tensors
+    elif isinstance(tensors, torch.Tensor):
+        # single tensor shape can be inferred
+        assert tensors.dim() == 4
+        image_sizes = [tensor.shape[-2:] for tensor in tensors]
+        return ImageList(tensors, image_sizes)
+    elif isinstance(tensors, (tuple, list)):
+        max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
+
+        # TODO Ideally, just remove this and let me model handle arbitrary
+        # input sizs
+        if size_divisible > 0:
+            import math
+
+            stride = size_divisible
+            max_size = list(max_size)
+            max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
+            max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
+            max_size = tuple(max_size)
+
+        batch_shape = (len(tensors),) + max_size
+        batched_imgs = tensors[0].new(*batch_shape).zero_()
+        for img, pad_img in zip(tensors, batched_imgs):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+
+        image_sizes = [im.shape[-2:] for im in tensors]
+
+        return ImageList(batched_imgs, image_sizes)
+    else:
+        raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
diff --git a/maskrcnn_benchmark/structures/keypoint.py b/maskrcnn_benchmark/structures/keypoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0d74c536f94bb8ba3f07f435769450c8971244a
--- /dev/null
+++ b/maskrcnn_benchmark/structures/keypoint.py
@@ -0,0 +1,212 @@
+import torch
+from maskrcnn_benchmark.config import cfg
+
+# transpose
+FLIP_LEFT_RIGHT = 0
+FLIP_TOP_BOTTOM = 1
+
+
+class Keypoints(object):
+    def __init__(self, keypoints, size, mode=None):
+        # FIXME remove check once we have better integration with device
+        # in my version this would consistently return a CPU tensor
+        device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device('cpu')
+        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
+        num_keypoints = keypoints.shape[0]
+        if num_keypoints:
+            keypoints = keypoints.view(num_keypoints, -1, 3)
+
+        # TODO should I split them?
+        # self.visibility = keypoints[..., 2]
+        self.keypoints = keypoints  # [..., :2]
+
+        self.size = size
+        self.mode = mode
+        self.extra_fields = {}
+
+    def crop(self, box):
+        raise NotImplementedError()
+
+    def resize(self, size, *args, **kwargs):
+        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
+        ratio_w, ratio_h = ratios
+        resized_data = self.keypoints.clone()
+        resized_data[..., 0] *= ratio_w
+        resized_data[..., 1] *= ratio_h
+        keypoints = type(self)(resized_data, size, self.mode)
+        for k, v in self.extra_fields.items():
+            keypoints.add_field(k, v)
+        return keypoints
+
+    def transpose(self, method):
+        if method not in (FLIP_LEFT_RIGHT,):
+            raise NotImplementedError(
+                "Only FLIP_LEFT_RIGHT implemented")
+
+        flip_inds = self.FLIP_INDS
+        flipped_data = self.keypoints[:, flip_inds]
+        width = self.size[0]
+        TO_REMOVE = 1
+        # Flip x coordinates
+        flipped_data[..., 0] = width - flipped_data[..., 0] - TO_REMOVE
+
+        # Maintain COCO convention that if visibility == 0, then x, y = 0
+        inds = flipped_data[..., 2] == 0
+        flipped_data[inds] = 0
+
+        keypoints = type(self)(flipped_data, self.size, self.mode)
+        for k, v in self.extra_fields.items():
+            keypoints.add_field(k, v)
+        return keypoints
+
+    def to(self, *args, **kwargs):
+        keypoints = type(self)(self.keypoints.to(*args, **kwargs), self.size, self.mode)
+        for k, v in self.extra_fields.items():
+            if hasattr(v, "to"):
+                v = v.to(*args, **kwargs)
+            keypoints.add_field(k, v)
+        return keypoints
+
+    def __getitem__(self, item):
+        keypoints = type(self)(self.keypoints[item], self.size, self.mode)
+        for k, v in self.extra_fields.items():
+            keypoints.add_field(k, v[item])
+        return keypoints
+
+    def add_field(self, field, field_data):
+        self.extra_fields[field] = field_data
+
+    def get_field(self, field):
+        return self.extra_fields[field]
+
+    def __repr__(self):
+        s = self.__class__.__name__ + '('
+        s += 'num_instances={}, '.format(len(self.keypoints))
+        s += 'image_width={}, '.format(self.size[0])
+        s += 'image_height={})'.format(self.size[1])
+        return s
+
+
+class PersonKeypoints(Keypoints):
+    _NAMES = [
+        'nose',
+        'left_eye',
+        'right_eye',
+        'left_ear',
+        'right_ear',
+        'left_shoulder',
+        'right_shoulder',
+        'left_elbow',
+        'right_elbow',
+        'left_wrist',
+        'right_wrist',
+        'left_hip',
+        'right_hip',
+        'left_knee',
+        'right_knee',
+        'left_ankle',
+        'right_ankle'
+    ]
+    _FLIP_MAP = {
+        'left_eye': 'right_eye',
+        'left_ear': 'right_ear',
+        'left_shoulder': 'right_shoulder',
+        'left_elbow': 'right_elbow',
+        'left_wrist': 'right_wrist',
+        'left_hip': 'right_hip',
+        'left_knee': 'right_knee',
+        'left_ankle': 'right_ankle'
+    }
+
+    def __init__(self, *args, **kwargs):
+        super(PersonKeypoints, self).__init__(*args, **kwargs)
+        if len(cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME)>0:
+            self.NAMES = cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME
+            self.FLIP_MAP = {l:r for l,r in PersonKeypoints._FLIP_MAP.items() if l in cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME}
+        else:
+            self.NAMES = PersonKeypoints._NAMES
+            self.FLIP_MAP = PersonKeypoints._FLIP_MAP
+
+        self.FLIP_INDS = self._create_flip_indices(self.NAMES, self.FLIP_MAP)
+        self.CONNECTIONS = self._kp_connections(self.NAMES)
+
+    def to_coco_format(self):
+        coco_result = []
+        for i in range(self.keypoints.shape[0]):
+            coco_kps = [0]*len(PersonKeypoints._NAMES)*3
+            for ki, name in enumerate(self.NAMES):
+                coco_kps[3*PersonKeypoints._NAMES.index(name)] = self.keypoints[i,ki,0].item()
+                coco_kps[3*PersonKeypoints._NAMES.index(name)+1] = self.keypoints[i,ki,1].item()
+                coco_kps[3*PersonKeypoints._NAMES.index(name)+2] = self.keypoints[i,ki,2].item()
+            coco_result.append(coco_kps)
+        return coco_result
+
+    def _create_flip_indices(self, names, flip_map):
+        full_flip_map = flip_map.copy()
+        full_flip_map.update({v: k for k, v in flip_map.items()})
+        flipped_names = [i if i not in full_flip_map else full_flip_map[i] for i in names]
+        flip_indices = [names.index(i) for i in flipped_names]
+        return torch.tensor(flip_indices)
+
+
+    def _kp_connections(self, keypoints):
+        CONNECTIONS = [
+            ['left_eye', 'right_eye'],
+            ['left_eye', 'nose'],
+            ['right_eye', 'nose'],
+            ['right_eye', 'right_ear'],
+            ['left_eye', 'left_ear'],
+            ['right_shoulder', 'right_elbow'],
+            ['right_elbow', 'right_wrist'],
+            ['left_shoulder', 'left_elbow'],
+            ['left_elbow', 'left_wrist'],
+            ['right_hip', 'right_knee'],
+            ['right_knee', 'right_ankle'],
+            ['left_hip', 'left_knee'],
+            ['left_knee', 'left_ankle'],
+            ['right_shoulder', 'left_shoulder'],
+            ['right_hip', 'left_hip'],
+        ]
+
+        kp_lines = [[keypoints.index(conn[0]), keypoints.index(conn[1])] for conn in CONNECTIONS
+                    if conn[0] in self.NAMES and conn[1] in self.NAMES]
+        return kp_lines
+
+
+
+# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop)
+def keypoints_to_heat_map(keypoints, rois, heatmap_size):
+    if rois.numel() == 0:
+        return rois.new().long(), rois.new().long()
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
\ No newline at end of file
diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a05e8ff93a352100e8463e074ee888d76e5b451
--- /dev/null
+++ b/maskrcnn_benchmark/structures/segmentation_mask.py
@@ -0,0 +1,214 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+import pycocotools.mask as mask_utils
+
+# transpose
+FLIP_LEFT_RIGHT = 0
+FLIP_TOP_BOTTOM = 1
+
+
+class Mask(object):
+    """
+    This class is unfinished and not meant for use yet
+    It is supposed to contain the mask for an object as
+    a 2d tensor
+    """
+
+    def __init__(self, masks, size, mode):
+        self.masks = masks
+        self.size = size
+        self.mode = mode
+
+    def transpose(self, method):
+        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
+            raise NotImplementedError(
+                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
+            )
+
+        width, height = self.size
+        if method == FLIP_LEFT_RIGHT:
+            dim = width
+            idx = 2
+        elif method == FLIP_TOP_BOTTOM:
+            dim = height
+            idx = 1
+
+        flip_idx = list(range(dim)[::-1])
+        flipped_masks = self.masks.index_select(dim, flip_idx)
+        return Mask(flipped_masks, self.size, self.mode)
+
+    def crop(self, box):
+        w, h = box[2] - box[0], box[3] - box[1]
+
+        cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]]
+        return Mask(cropped_masks, size=(w, h), mode=self.mode)
+
+    def resize(self, size, *args, **kwargs):
+        pass
+
+
+class Polygons(object):
+    """
+    This class holds a set of polygons that represents a single instance
+    of an object mask. The object can be represented as a set of
+    polygons
+    """
+
+    def __init__(self, polygons, size, mode):
+        # assert isinstance(polygons, list), '{}'.format(polygons)
+        if isinstance(polygons, list):
+            polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons]
+        elif isinstance(polygons, Polygons):
+            polygons = polygons.polygons
+
+        self.polygons = polygons
+        self.size = size
+        self.mode = mode
+
+    def transpose(self, method):
+        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
+            raise NotImplementedError(
+                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
+            )
+
+        flipped_polygons = []
+        width, height = self.size
+        if method == FLIP_LEFT_RIGHT:
+            dim = width
+            idx = 0
+        elif method == FLIP_TOP_BOTTOM:
+            dim = height
+            idx = 1
+
+        for poly in self.polygons:
+            p = poly.clone()
+            TO_REMOVE = 1
+            p[idx::2] = dim - poly[idx::2] - TO_REMOVE
+            flipped_polygons.append(p)
+
+        return Polygons(flipped_polygons, size=self.size, mode=self.mode)
+
+    def crop(self, box):
+        w, h = box[2] - box[0], box[3] - box[1]
+
+        # TODO chck if necessary
+        w = max(w, 1)
+        h = max(h, 1)
+
+        cropped_polygons = []
+        for poly in self.polygons:
+            p = poly.clone()
+            p[0::2] = p[0::2] - box[0]  # .clamp(min=0, max=w)
+            p[1::2] = p[1::2] - box[1]  # .clamp(min=0, max=h)
+            cropped_polygons.append(p)
+
+        return Polygons(cropped_polygons, size=(w, h), mode=self.mode)
+
+    def resize(self, size, *args, **kwargs):
+        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
+        if ratios[0] == ratios[1]:
+            ratio = ratios[0]
+            scaled_polys = [p * ratio for p in self.polygons]
+            return Polygons(scaled_polys, size, mode=self.mode)
+
+        ratio_w, ratio_h = ratios
+        scaled_polygons = []
+        for poly in self.polygons:
+            p = poly.clone()
+            p[0::2] *= ratio_w
+            p[1::2] *= ratio_h
+            scaled_polygons.append(p)
+
+        return Polygons(scaled_polygons, size=size, mode=self.mode)
+
+    def convert(self, mode):
+        width, height = self.size
+        if mode == "mask":
+            rles = mask_utils.frPyObjects(
+                [p.detach().numpy() for p in self.polygons], height, width
+            )
+            rle = mask_utils.merge(rles)
+            mask = mask_utils.decode(rle)
+            mask = torch.from_numpy(mask)
+            # TODO add squeeze?
+            return mask
+
+    def __repr__(self):
+        s = self.__class__.__name__ + "("
+        s += "num_polygons={}, ".format(len(self.polygons))
+        s += "image_width={}, ".format(self.size[0])
+        s += "image_height={}, ".format(self.size[1])
+        s += "mode={})".format(self.mode)
+        return s
+
+
+class SegmentationMask(object):
+    """
+    This class stores the segmentations for all objects in the image
+    """
+
+    def __init__(self, polygons, size, mode=None):
+        """
+        Arguments:
+            polygons: a list of list of lists of numbers. The first
+                level of the list correspond to individual instances,
+                the second level to all the polygons that compose the
+                object, and the third level to the polygon coordinates.
+        """
+        assert isinstance(polygons, list)
+
+        self.polygons = [Polygons(p, size, mode) for p in polygons]
+        self.size = size
+        self.mode = mode
+
+    def transpose(self, method):
+        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
+            raise NotImplementedError(
+                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
+            )
+
+        flipped = []
+        for polygon in self.polygons:
+            flipped.append(polygon.transpose(method))
+        return SegmentationMask(flipped, size=self.size, mode=self.mode)
+
+    def crop(self, box):
+        w, h = box[2] - box[0], box[3] - box[1]
+        cropped = []
+        for polygon in self.polygons:
+            cropped.append(polygon.crop(box))
+        return SegmentationMask(cropped, size=(w, h), mode=self.mode)
+
+    def resize(self, size, *args, **kwargs):
+        scaled = []
+        for polygon in self.polygons:
+            scaled.append(polygon.resize(size, *args, **kwargs))
+        return SegmentationMask(scaled, size=size, mode=self.mode)
+
+    def to(self, *args, **kwargs):
+        return self
+
+    def __getitem__(self, item):
+        if isinstance(item, (int, slice)):
+            selected_polygons = [self.polygons[item]]
+        else:
+            # advanced indexing on a single dimension
+            selected_polygons = []
+            if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
+                item = item.nonzero()
+                item = item.squeeze(1) if item.numel() > 0 else item
+                item = item.tolist()
+            for i in item:
+                selected_polygons.append(self.polygons[i])
+        return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)
+
+    def __iter__(self):
+        return iter(self.polygons)
+
+    def __repr__(self):
+        s = self.__class__.__name__ + "("
+        s += "num_instances={}, ".format(len(self.polygons))
+        s += "image_width={}, ".format(self.size[0])
+        s += "image_height={})".format(self.size[1])
+        return s
diff --git a/maskrcnn_benchmark/utils/README.md b/maskrcnn_benchmark/utils/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3c35e560d1b3e3fb6cfc5e5a5653a283b1c603e3
--- /dev/null
+++ b/maskrcnn_benchmark/utils/README.md
@@ -0,0 +1,5 @@
+# Utility functions
+
+This folder contain utility functions that are not used in the
+core library, but are useful for building models or training
+code using the config system.
diff --git a/maskrcnn_benchmark/utils/__init__.py b/maskrcnn_benchmark/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/maskrcnn_benchmark/utils/amp.py b/maskrcnn_benchmark/utils/amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b1a4f5bd5baaf888829aca231af445b4600650
--- /dev/null
+++ b/maskrcnn_benchmark/utils/amp.py
@@ -0,0 +1,14 @@
+from contextlib import contextmanager
+
+@contextmanager
+def nullcontext(enter_result=None, **kwargs):
+    yield enter_result
+
+try:
+    from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
+except:
+    print('[Warning] Library for automatic mixed precision is not found, AMP is disabled!!')
+    GradScaler = nullcontext
+    autocast = nullcontext
+    custom_fwd = nullcontext
+    custom_bwd = nullcontext
\ No newline at end of file
diff --git a/maskrcnn_benchmark/utils/big_model_loading.py b/maskrcnn_benchmark/utils/big_model_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..25dc5429f2b771a96edd402c569bf140dac7fc33
--- /dev/null
+++ b/maskrcnn_benchmark/utils/big_model_loading.py
@@ -0,0 +1,80 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from collections import OrderedDict
+
+
+def tf2th(conv_weights):
+    """Possibly convert HWIO to OIHW."""
+    if conv_weights.ndim == 4:
+        conv_weights = conv_weights.transpose([3, 2, 0, 1])
+    return torch.from_numpy(conv_weights)
+
+
+def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
+    import re
+    layer_keys = sorted(state_dict.keys())
+    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
+        if not stage_with_dcn:
+            continue
+        for old_key in layer_keys:
+            pattern = ".*block{}.*conv2.*".format(ix)
+            r = re.match(pattern, old_key)
+            if r is None:
+                continue
+            for param in ["weight", "bias"]:
+                if old_key.find(param) is -1:
+                    continue
+                if 'unit01' in old_key:
+                    continue
+                new_key = old_key.replace(
+                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
+                )
+                print("pattern: {}, old_key: {}, new_key: {}".format(
+                    pattern, old_key, new_key
+                ))
+                # Calculate SD conv weight
+                w = state_dict[old_key]
+                v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
+                w = (w - m) / torch.sqrt(v + 1e-10)
+
+                state_dict[new_key] = w
+                del state_dict[old_key]
+    return state_dict
+
+
+def load_big_format(cfg, f):
+    model = OrderedDict()
+    weights = np.load(f)
+
+    cmap = {'a':1, 'b':2, 'c':3}
+    for key, val in weights.items():
+        old_key = key.replace('resnet/', '')
+        if 'root_block' in old_key:
+            new_key = 'root.conv.weight'
+        elif '/proj/standardized_conv2d/kernel' in old_key:
+            key_pattern = old_key.replace('/proj/standardized_conv2d/kernel', '').replace('resnet/', '')
+            bname, uname, cidx = key_pattern.split('/')
+            new_key = '{}.downsample.{}.conv{}.weight'.format(bname,uname,cmap[cidx])
+        elif '/standardized_conv2d/kernel' in old_key:
+            key_pattern = old_key.replace('/standardized_conv2d/kernel', '').replace('resnet/', '')
+            bname, uname, cidx = key_pattern.split('/')
+            new_key = '{}.{}.conv{}.weight'.format(bname,uname,cmap[cidx])
+        elif '/group_norm/gamma' in old_key:
+            key_pattern = old_key.replace('/group_norm/gamma', '').replace('resnet/', '')
+            bname, uname, cidx = key_pattern.split('/')
+            new_key = '{}.{}.gn{}.weight'.format(bname,uname,cmap[cidx])
+        elif '/group_norm/beta' in old_key:
+            key_pattern = old_key.replace('/group_norm/beta', '').replace('resnet/', '')
+            bname, uname, cidx = key_pattern.split('/')
+            new_key = '{}.{}.gn{}.bias'.format(bname,uname,cmap[cidx])
+        else:
+            print('Unknown key {}'.format(old_key))
+            continue
+        print('Map {} -> {}'.format(key, new_key))
+        model[new_key] = tf2th(val)
+
+    model = _rename_conv_weights_for_deformable_conv_layers(model, cfg)
+
+    return dict(model=model)
diff --git a/maskrcnn_benchmark/utils/c2_model_loading.py b/maskrcnn_benchmark/utils/c2_model_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..e51eea3a16aba9d1f392ac10a1602b1023938c30
--- /dev/null
+++ b/maskrcnn_benchmark/utils/c2_model_loading.py
@@ -0,0 +1,207 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import pickle
+from collections import OrderedDict
+
+import torch
+
+from maskrcnn_benchmark.utils.model_serialization import load_state_dict
+from maskrcnn_benchmark.utils.registry import Registry
+
+
+def _rename_basic_resnet_weights(layer_keys):
+    layer_keys = [k.replace("_", ".") for k in layer_keys]
+    layer_keys = [k.replace(".w", ".weight") for k in layer_keys]
+    layer_keys = [k.replace(".bn", "_bn") for k in layer_keys]
+    layer_keys = [k.replace(".b", ".bias") for k in layer_keys]
+    layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys]
+    layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys]
+    layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys]
+    layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys]
+    layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys]
+
+    # RPN / Faster RCNN
+    layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys]
+    layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys]
+    layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys]
+    layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys]
+
+    # Affine-Channel -> BatchNorm enaming
+    layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys]
+
+    # Make torchvision-compatible
+    layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys]
+
+    layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys]
+    layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys]
+    layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys]
+    layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys]
+
+    layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys]
+
+    layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys]
+    layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys]
+
+    # GroupNorm
+    layer_keys = [k.replace("conv1.gn.s", "bn1.weight") for k in layer_keys]
+    layer_keys = [k.replace("conv1.gn.bias", "bn1.bias") for k in layer_keys]
+    layer_keys = [k.replace("conv2.gn.s", "bn2.weight") for k in layer_keys]
+    layer_keys = [k.replace("conv2.gn.bias", "bn2.bias") for k in layer_keys]
+    layer_keys = [k.replace("conv3.gn.s", "bn3.weight") for k in layer_keys]
+    layer_keys = [k.replace("conv3.gn.bias", "bn3.bias") for k in layer_keys]
+    layer_keys = [k.replace("downsample.0.gn.s", "downsample.1.weight") \
+        for k in layer_keys]
+    layer_keys = [k.replace("downsample.0.gn.bias", "downsample.1.bias") \
+        for k in layer_keys]
+
+    return layer_keys
+
+def _rename_fpn_weights(layer_keys, stage_names):
+    for mapped_idx, stage_name in enumerate(stage_names, 1):
+        suffix = ""
+        if mapped_idx < 4:
+            suffix = ".lateral"
+        layer_keys = [
+            k.replace("fpn.inner.layer{}.sum{}".format(stage_name, suffix), "fpn_inner{}".format(mapped_idx)) for k in layer_keys
+        ]
+        layer_keys = [k.replace("fpn.layer{}.sum".format(stage_name), "fpn_layer{}".format(mapped_idx)) for k in layer_keys]
+
+
+    layer_keys = [k.replace("rpn.conv.fpn2", "rpn.conv") for k in layer_keys]
+    layer_keys = [k.replace("rpn.bbox_pred.fpn2", "rpn.bbox_pred") for k in layer_keys]
+    layer_keys = [
+        k.replace("rpn.cls_logits.fpn2", "rpn.cls_logits") for k in layer_keys
+    ]
+
+    return layer_keys
+
+
+def _rename_weights_for_resnet(weights, stage_names):
+    original_keys = sorted(weights.keys())
+    layer_keys = sorted(weights.keys())
+
+    # for X-101, rename output to fc1000 to avoid conflicts afterwards
+    layer_keys = [k if k != "pred_b" else "fc1000_b" for k in layer_keys]
+    layer_keys = [k if k != "pred_w" else "fc1000_w" for k in layer_keys]
+
+    # performs basic renaming: _ -> . , etc
+    layer_keys = _rename_basic_resnet_weights(layer_keys)
+
+    # FPN
+    layer_keys = _rename_fpn_weights(layer_keys, stage_names)
+
+    # Mask R-CNN
+    layer_keys = [k.replace("mask.fcn.logits", "mask_fcn_logits") for k in layer_keys]
+    layer_keys = [k.replace(".[mask].fcn", "mask_fcn") for k in layer_keys]
+    layer_keys = [k.replace("conv5.mask", "conv5_mask") for k in layer_keys]
+
+    # Keypoint R-CNN
+    layer_keys = [k.replace("kps.score.lowres", "kps_score_lowres") for k in layer_keys]
+    layer_keys = [k.replace("kps.score", "kps_score") for k in layer_keys]
+    layer_keys = [k.replace("conv.fcn", "conv_fcn") for k in layer_keys]
+
+    # Rename for our RPN structure
+    layer_keys = [k.replace("rpn.", "rpn.head.") for k in layer_keys]
+
+    key_map = {k: v for k, v in zip(original_keys, layer_keys)}
+
+    logger = logging.getLogger(__name__)
+    logger.info("Remapping C2 weights")
+    max_c2_key_size = max([len(k) for k in original_keys if "_momentum" not in k])
+
+    new_weights = OrderedDict()
+    for k in original_keys:
+        v = weights[k]
+        if "_momentum" in k:
+            continue
+        if 'weight_order' in k:
+            continue
+        # if 'fc1000' in k:
+        #     continue
+        w = torch.from_numpy(v)
+        # if "bn" in k:
+        #     w = w.view(1, -1, 1, 1)
+        logger.info("C2 name: {: <{}} mapped name: {}".format(k, max_c2_key_size, key_map[k]))
+        new_weights[key_map[k]] = w
+
+    return new_weights
+
+
+def _load_c2_pickled_weights(file_path):
+    with open(file_path, "rb") as f:
+        if torch._six.PY3:
+            data = pickle.load(f, encoding="latin1")
+        else:
+            data = pickle.load(f)
+    if "blobs" in data:
+        weights = data["blobs"]
+    else:
+        weights = data
+    return weights
+
+
+def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
+    import re
+    logger = logging.getLogger(__name__)
+    logger.info("Remapping conv weights for deformable conv weights")
+    layer_keys = sorted(state_dict.keys())
+    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
+        if not stage_with_dcn:
+            continue
+        for old_key in layer_keys:
+            pattern = ".*layer{}.*conv2.*".format(ix)
+            r = re.match(pattern, old_key)
+            if r is None:
+                continue
+            for param in ["weight", "bias"]:
+                if old_key.find(param) is -1:
+                    continue
+                new_key = old_key.replace(
+                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
+                )
+                logger.info("pattern: {}, old_key: {}, new_key: {}".format(
+                    pattern, old_key, new_key
+                ))
+                state_dict[new_key] = state_dict[old_key]
+                del state_dict[old_key]
+    return state_dict
+
+
+_C2_STAGE_NAMES = {
+    "R-50": ["1.2", "2.3", "3.5", "4.2"],
+    "R-101": ["1.2", "2.3", "3.22", "4.2"],
+}
+
+C2_FORMAT_LOADER = Registry()
+
+
+@C2_FORMAT_LOADER.register("R-50-C4")
+@C2_FORMAT_LOADER.register("R-50-C5")
+@C2_FORMAT_LOADER.register("R-101-C4")
+@C2_FORMAT_LOADER.register("R-101-C5")
+@C2_FORMAT_LOADER.register("R-50-FPN")
+@C2_FORMAT_LOADER.register("R-50-FPN-RETINANET")
+@C2_FORMAT_LOADER.register("R-50-FPN-FCOS")
+@C2_FORMAT_LOADER.register("R-101-FPN")
+@C2_FORMAT_LOADER.register("R-101-FPN-RETINANET")
+@C2_FORMAT_LOADER.register("R-101-FPN-FCOS")
+def load_resnet_c2_format(cfg, f):
+    state_dict = _load_c2_pickled_weights(f)
+    conv_body = cfg.MODEL.BACKBONE.CONV_BODY
+    arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "").replace("-RETINANET", "").replace("-FCOS", "")
+    stages = _C2_STAGE_NAMES[arch]
+    state_dict = _rename_weights_for_resnet(state_dict, stages)
+    # ***********************************
+    # for deformable convolutional layer
+    state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg)
+    # ***********************************
+    return dict(model=state_dict)
+
+
+def load_c2_format(cfg, f):
+    return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f)
diff --git a/maskrcnn_benchmark/utils/checkpoint.py b/maskrcnn_benchmark/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..c35a8f6e7b6da8478d0100f8c240e5ee1d50ccba
--- /dev/null
+++ b/maskrcnn_benchmark/utils/checkpoint.py
@@ -0,0 +1,163 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import os
+
+import torch
+
+from maskrcnn_benchmark.utils.model_serialization import load_state_dict
+from maskrcnn_benchmark.utils.c2_model_loading import load_c2_format
+from maskrcnn_benchmark.utils.big_model_loading import load_big_format
+from maskrcnn_benchmark.utils.pretrain_model_loading import load_pretrain_format
+from maskrcnn_benchmark.utils.imports import import_file
+from maskrcnn_benchmark.utils.model_zoo import cache_url
+
+
+class Checkpointer(object):
+    def __init__(
+        self,
+        model,
+        optimizer=None,
+        scheduler=None,
+        save_dir="",
+        save_to_disk=None,
+        logger=None,
+    ):
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.save_dir = save_dir
+        self.save_to_disk = save_to_disk
+        if logger is None:
+            logger = logging.getLogger(__name__)
+        self.logger = logger
+
+    def save(self, name, **kwargs):
+        if not self.save_dir:
+            return
+
+        if not self.save_to_disk:
+            return
+
+        data = {}
+        data["model"] = self.model.state_dict()
+        if self.optimizer is not None:
+            data["optimizer"] = self.optimizer.state_dict()
+        if self.scheduler is not None:
+            if isinstance(self.scheduler, list):
+                data["scheduler"] = [scheduler.state_dict() for scheduler in self.scheduler]
+            else:
+                data["scheduler"] = self.scheduler.state_dict()
+        data.update(kwargs)
+
+        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
+        self.logger.info("Saving checkpoint to {}".format(save_file))
+        torch.save(data, save_file)
+        # self.tag_last_checkpoint(save_file)
+        # use relative path name to save the checkpoint
+        self.tag_last_checkpoint("{}.pth".format(name))
+
+    def load(self, f=None, force=False, keyword="model", skip_optimizer =False):
+        resume = False
+        if self.has_checkpoint() and not force:
+            # override argument with existing checkpoint
+            f = self.get_checkpoint_file()
+            # get the absolute path
+            f = os.path.join(self.save_dir, f)
+            resume = True
+        if not f:
+            # no checkpoint could be found
+            self.logger.info("No checkpoint found. Initializing model from scratch")
+            return {}
+        self.logger.info("Loading checkpoint from {}".format(f))
+        checkpoint = self._load_file(f)
+        self._load_model(checkpoint, keyword=keyword)
+        # if resume training, load optimizer and scheduler,
+        # otherwise use the specified LR in config yaml for fine-tuning
+        if resume and not skip_optimizer:
+            if "optimizer" in checkpoint and self.optimizer:
+                self.logger.info("Loading optimizer from {}".format(f))
+                self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
+            if "scheduler" in checkpoint and self.scheduler:
+                self.logger.info("Loading scheduler from {}".format(f))
+                if isinstance(self.scheduler, list):
+                    for scheduler, state_dict in zip(self.scheduler, checkpoint.pop("scheduler")):
+                        scheduler.load_state_dict(state_dict)
+                else:
+                    self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
+
+            # return any further checkpoint data
+            return checkpoint
+        else:
+            return {}
+
+    def has_checkpoint(self):
+        save_file = os.path.join(self.save_dir, "last_checkpoint")
+        return os.path.exists(save_file)
+
+    def get_checkpoint_file(self):
+        save_file = os.path.join(self.save_dir, "last_checkpoint")
+        try:
+            with open(save_file, "r") as f:
+                last_saved = f.read()
+                last_saved = last_saved.strip()
+        except IOError:
+            # if file doesn't exist, maybe because it has just been
+            # deleted by a separate process
+            last_saved = ""
+        return last_saved
+
+    def tag_last_checkpoint(self, last_filename):
+        save_file = os.path.join(self.save_dir, "last_checkpoint")
+        with open(save_file, "w") as f:
+            f.write(last_filename)
+
+    def _load_file(self, f):
+        return torch.load(f, map_location=torch.device("cpu"))
+
+    def _load_model(self, checkpoint, keyword="model"):
+        load_state_dict(self.model, checkpoint.pop(keyword))
+
+
+class DetectronCheckpointer(Checkpointer):
+    def __init__(
+        self,
+        cfg,
+        model,
+        optimizer=None,
+        scheduler=None,
+        save_dir="",
+        save_to_disk=None,
+        logger=None,
+    ):
+        super(DetectronCheckpointer, self).__init__(
+            model, optimizer, scheduler, save_dir, save_to_disk, logger
+        )
+        self.cfg = cfg.clone()
+
+    def _load_file(self, f):
+        # catalog lookup
+        if f.startswith("catalog://"):
+            paths_catalog = import_file(
+                "maskrcnn_benchmark.config.paths_catalog", self.cfg.PATHS_CATALOG, True
+            )
+            catalog_f = paths_catalog.ModelCatalog.get(f[len("catalog://") :])
+            self.logger.info("{} points to {}".format(f, catalog_f))
+            f = catalog_f
+        # download url files
+        if f.startswith("http"):
+            # if the file is a url path, download it and cache it
+            cached_f = cache_url(f)
+            self.logger.info("url {} cached in {}".format(f, cached_f))
+            f = cached_f
+        # convert Caffe2 checkpoint from pkl
+        if f.endswith(".pkl"):
+            return load_c2_format(self.cfg, f)
+        if f.endswith(".big"):
+            return load_big_format(self.cfg, f)
+        if f.endswith(".pretrain"):
+            return load_pretrain_format(self.cfg, f)
+        # load native detectron.pytorch checkpoint
+        loaded = super(DetectronCheckpointer, self)._load_file(f)
+        if "model" not in loaded:
+            loaded = dict(model=loaded)
+        return loaded
diff --git a/maskrcnn_benchmark/utils/collect_env.py b/maskrcnn_benchmark/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..d93d6164aed31b783c58581cc85c183e1f1805be
--- /dev/null
+++ b/maskrcnn_benchmark/utils/collect_env.py
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import PIL
+
+from torch.utils.collect_env import get_pretty_env_info
+
+
+def get_pil_version():
+    return "\n        Pillow ({})".format(PIL.__version__)
+
+
+def collect_env_info():
+    env_str = get_pretty_env_info()
+    env_str += get_pil_version()
+    return env_str
diff --git a/maskrcnn_benchmark/utils/comm.py b/maskrcnn_benchmark/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1222d2b36d83edb659973cf2253e4d5201d823c
--- /dev/null
+++ b/maskrcnn_benchmark/utils/comm.py
@@ -0,0 +1,157 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import pickle
+import time
+import functools
+import logging
+import torch
+import torch.distributed as dist
+import numpy as np
+
+
+def get_world_size():
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when
+    using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
+    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+    if local_size != max_size:
+        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+def broadcast_data(data):
+    if not torch.distributed.is_initialized():
+        return data
+    rank = dist.get_rank()
+    if rank == 0:
+        data_tensor = torch.tensor(data + [0], device="cuda")
+    else:
+        data_tensor = torch.tensor(data + [1], device="cuda")
+    torch.distributed.broadcast(data_tensor, 0)
+    while data_tensor.cpu().numpy()[-1] == 1:
+        time.sleep(1)
+
+    return data_tensor.cpu().numpy().tolist()[:-1]
+
+
+def reduce_sum(tensor):
+    if get_world_size() <= 1:
+        return tensor
+
+    tensor = tensor.clone()
+    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    return tensor
+
+
+def shared_random_seed():
+    """
+    Returns:
+        int: a random number that is the same across all workers.
+            If workers need a shared RNG, they can use this shared seed to
+            create one.
+
+    All workers must call this function, otherwise it will deadlock.
+    """
+    ints = np.random.randint(2 ** 31)
+    all_ints = all_gather(ints)
+    return all_ints[0]
\ No newline at end of file
diff --git a/maskrcnn_benchmark/utils/cv2_util.py b/maskrcnn_benchmark/utils/cv2_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..268db9e5be1dc8094c39a1fddc1bfd7a89a7ca47
--- /dev/null
+++ b/maskrcnn_benchmark/utils/cv2_util.py
@@ -0,0 +1,24 @@
+"""
+Module for cv2 utility functions and maintaining version compatibility
+between 3.x and 4.x
+"""
+import cv2
+
+
+def findContours(*args, **kwargs):
+    """
+    Wraps cv2.findContours to maintain compatiblity between versions
+    3 and 4
+
+    Returns:
+        contours, hierarchy
+    """
+    if cv2.__version__.startswith('4'):
+        contours, hierarchy = cv2.findContours(*args, **kwargs)
+    elif cv2.__version__.startswith('3'):
+        _, contours, hierarchy = cv2.findContours(*args, **kwargs)
+    else:
+        raise AssertionError(
+            'cv2 must be either version 3 or 4 to call this method')
+
+    return contours, hierarchy
diff --git a/maskrcnn_benchmark/utils/dist.py b/maskrcnn_benchmark/utils/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..de7ac00c0eed1acc723df95f79367af82f79ddb0
--- /dev/null
+++ b/maskrcnn_benchmark/utils/dist.py
@@ -0,0 +1,228 @@
+# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Utilities related to distributed mode.
+
+By default, the reduce of metrics and such are done on GPU, since it's more straightforward (we reuse the NCCL backend)
+If you want to reduce on CPU instead (required for big datasets like GQA), use the env variable MDETR_CPU_REDUCE=1
+"""
+import functools
+import io
+import os
+
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks
+    The result is cached.
+    """
+
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+
+    return dist.group.WORLD
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    cpu_group = None
+    if os.getenv("MDETR_CPU_REDUCE") == "1":
+        cpu_group = _get_global_gloo_group()
+
+    buffer = io.BytesIO()
+    torch.save(data, buffer)
+    data_view = buffer.getbuffer()
+    device = "cuda" if cpu_group is None else "cpu"
+    tensor = torch.ByteTensor(data_view).to(device)
+
+    # obtain Tensor size of each rank
+    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
+    size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
+    if cpu_group is None:
+        dist.all_gather(size_list, local_size)
+    else:
+        print("gathering on cpu")
+        dist.all_gather(size_list, local_size, group=cpu_group)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+    assert isinstance(local_size.item(), int)
+    local_size = int(local_size.item())
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
+    if local_size != max_size:
+        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    if cpu_group is None:
+        dist.all_gather(tensor_list, tensor)
+    else:
+        dist.all_gather(tensor_list, tensor, group=cpu_group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
+        buffer = io.BytesIO(tensor.cpu().numpy())
+        obj = torch.load(buffer)
+        data_list.append(obj)
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that all processes
+    have the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.all_reduce(values)
+        if average:
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    """
+    Returns:
+        True if distributed training is enabled
+    """
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    """
+    Returns:
+        The number of processes in the process group
+    """
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    """
+    Returns:
+        The rank of the current process within the global process group.
+    """
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    """
+    Returns:
+        The rank of the current process within the local (per-machine) process group.
+    """
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    assert _LOCAL_PROCESS_GROUP is not None
+    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+    """
+    Returns:
+        The size of the per-machine process group,
+        i.e. the number of processes per machine.
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process():
+    """Return true if the current process is the main one"""
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    """Utility function to save only from the main process"""
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    """Initialize distributed training, if appropriate"""
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
+
+    dist.init_process_group(
+        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+    )
+    dist.barrier()
+    setup_for_distributed(args.rank == 0)
diff --git a/maskrcnn_benchmark/utils/ema.py b/maskrcnn_benchmark/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..1da65bc07a0365bf950aac5232ccbe666ae85741
--- /dev/null
+++ b/maskrcnn_benchmark/utils/ema.py
@@ -0,0 +1,46 @@
+from copy import deepcopy
+from collections import OrderedDict
+import torch
+
+
+class ModelEma:
+    def __init__(self, model, decay=0.9999, device=''):
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device
+        if device:
+            self.ema.to(device=device)
+        self.ema_is_dp = hasattr(self.ema, 'module')
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def load_checkpoint(self, checkpoint):
+        if isinstance(checkpoint, str):
+            checkpoint = torch.load(checkpoint)
+
+        assert isinstance(checkpoint, dict)
+        if 'model_ema' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['model_ema'].items():
+                if self.ema_is_dp:
+                    name = k if k.startswith('module') else 'module.' + k
+                else:
+                    name = k.replace('module.', '') if k.startswith('module') else k
+                new_state_dict[name] = v
+            self.ema.load_state_dict(new_state_dict)
+
+    def state_dict(self):
+        return self.ema.state_dict()
+
+    def update(self, model):
+        pre_module = hasattr(model, 'module') and not self.ema_is_dp
+        with torch.no_grad():
+            curr_msd = model.state_dict()
+            for k, ema_v in self.ema.state_dict().items():
+                k = 'module.' + k if pre_module else k
+                model_v = curr_msd[k].detach()
+                if self.device:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
+
diff --git a/maskrcnn_benchmark/utils/env.py b/maskrcnn_benchmark/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3e19c760c076c3dfdb89cf2bf34b7ed8866a019
--- /dev/null
+++ b/maskrcnn_benchmark/utils/env.py
@@ -0,0 +1,37 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import os
+
+from maskrcnn_benchmark.utils.imports import import_file
+
+
+def setup_environment():
+    """Perform environment setup work. The default setup is a no-op, but this
+    function allows the user to specify a Python source file that performs
+    custom setup work that may be necessary to their computing environment.
+    """
+    custom_module_path = os.environ.get("TORCH_DETECTRON_ENV_MODULE")
+    if custom_module_path:
+        setup_custom_environment(custom_module_path)
+    else:
+        # The default setup is a no-op
+        pass
+
+
+def setup_custom_environment(custom_module_path):
+    """Load custom environment setup from a Python source file and run the setup
+    function.
+    """
+    module = import_file("maskrcnn_benchmark.utils.env.custom_module", custom_module_path)
+    assert hasattr(module, "setup_environment") and callable(
+        module.setup_environment
+    ), (
+        "Custom environment module defined in {} does not have the "
+        "required callable attribute 'setup_environment'."
+    ).format(
+        custom_module_path
+    )
+    module.setup_environment()
+
+
+# Force environment setup when this module is imported
+setup_environment()
diff --git a/maskrcnn_benchmark/utils/flops.py b/maskrcnn_benchmark/utils/flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2e3d72ff32c3d2824099517356067ad55c722a2
--- /dev/null
+++ b/maskrcnn_benchmark/utils/flops.py
@@ -0,0 +1,249 @@
+import argparse
+import logging
+import torch
+import torch.nn as nn
+import timeit
+
+from maskrcnn_benchmark.layers import *
+from maskrcnn_benchmark.modeling.backbone.resnet_big import StdConv2d
+from maskrcnn_benchmark.modeling.backbone.fpn import *
+from maskrcnn_benchmark.modeling.rpn.inference import *
+from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import PostProcessor
+from maskrcnn_benchmark.modeling.rpn.anchor_generator import BufferList
+
+
+def profile(model, input_size, custom_ops={}, device="cpu", verbose=False, extra_args={}, return_time=False):
+    handler_collection = []
+
+    def add_hooks(m):
+        if len(list(m.children())) > 0:
+            return
+
+        m.register_buffer('total_ops', torch.zeros(1))
+        m.register_buffer('total_params', torch.zeros(1))
+
+        for p in m.parameters():
+            m.total_params += torch.Tensor([p.numel()])
+
+        m_type = type(m)
+        fn = None
+
+        if m_type in custom_ops:
+            fn = custom_ops[m_type]
+        elif m_type in register_hooks:
+            fn = register_hooks[m_type]
+        else:
+            print("Not implemented for ", m)
+
+        if fn is not None:
+            if verbose:
+                print("Register FLOP counter for module %s" % str(m))
+            handler = m.register_forward_hook(fn)
+            handler_collection.append(handler)
+
+    original_device = model.parameters().__next__().device
+    training = model.training
+
+    model.eval().to(device)
+    model.apply(add_hooks)
+
+    x = torch.zeros(input_size).to(device)
+    with torch.no_grad():
+        tic = timeit.time.perf_counter()
+        model(x, **extra_args)
+        toc = timeit.time.perf_counter()
+        total_time = toc-tic
+
+    total_ops = 0
+    total_params = 0
+    for m in model.modules():
+        if len(list(m.children())) > 0:  # skip for non-leaf module
+            continue
+        total_ops += m.total_ops
+        total_params += m.total_params
+
+    total_ops = total_ops.item()
+    total_params = total_params.item()
+
+    model.train(training).to(original_device)
+    for handler in handler_collection:
+        handler.remove()
+
+    if return_time:
+        return total_ops, total_params, total_time
+    else:
+        return total_ops, total_params
+
+
+multiply_adds = 1
+def count_conv2d(m, x, y):
+    x = x[0]
+    cin = m.in_channels
+    cout = m.out_channels
+    kh, kw = m.kernel_size
+    batch_size = x.size()[0]
+    out_h = y.size(2)
+    out_w = y.size(3)
+    # ops per output element
+    # kernel_mul = kh * kw * cin
+    # kernel_add = kh * kw * cin - 1
+    kernel_ops = multiply_adds * kh * kw * cin // m.groups
+    bias_ops = 1 if m.bias is not None else 0
+    ops_per_element = kernel_ops + bias_ops
+    # total ops
+    # num_out_elements = y.numel()
+    output_elements = batch_size * out_w * out_h * cout
+    total_ops = output_elements * ops_per_element
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_convtranspose2d(m, x, y):
+    x = x[0]
+    cin = m.in_channels
+    cout = m.out_channels
+    kh, kw = m.kernel_size
+    batch_size = x.size()[0]
+    out_h = y.size(2)
+    out_w = y.size(3)
+    # ops per output element
+    # kernel_mul = kh * kw * cin
+    # kernel_add = kh * kw * cin - 1
+    kernel_ops = multiply_adds * kh * kw * cin // m.groups
+    bias_ops = 1 if m.bias is not None else 0
+    ops_per_element = kernel_ops + bias_ops
+    # total ops
+    # num_out_elements = y.numel()
+    # output_elements = batch_size * out_w * out_h * cout
+    ops_per_element = m.weight.nelement()
+    output_elements = y.nelement()
+    total_ops = output_elements * ops_per_element
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_bn(m, x, y):
+    x = x[0]
+    nelements = x.numel()
+    # subtract, divide, gamma, beta
+    total_ops = 4*nelements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_relu(m, x, y):
+    x = x[0]
+    nelements = x.numel()
+    total_ops = nelements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_softmax(m, x, y):
+    x = x[0]
+    batch_size, nfeatures = x.size()
+    total_exp = nfeatures
+    total_add = nfeatures - 1
+    total_div = nfeatures
+    total_ops = batch_size * (total_exp + total_add + total_div)
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_maxpool(m, x, y):
+    kernel_ops = torch.prod(torch.Tensor([m.kernel_size]))
+    num_elements = y.numel()
+    total_ops = kernel_ops * num_elements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_adap_maxpool(m, x, y):
+    kernel = torch.Tensor([*(x[0].shape[2:])])//torch.Tensor(list((m.output_size,))).squeeze()
+    kernel_ops = torch.prod(kernel)
+    num_elements = y.numel()
+    total_ops = kernel_ops * num_elements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_avgpool(m, x, y):
+    total_add = torch.prod(torch.Tensor([m.kernel_size]))
+    total_div = 1
+    kernel_ops = total_add + total_div
+    num_elements = y.numel()
+    total_ops = kernel_ops * num_elements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_adap_avgpool(m, x, y):
+    kernel = torch.Tensor([*(x[0].shape[2:])])//torch.Tensor(list((m.output_size,))).squeeze()
+    total_add = torch.prod(kernel)
+    total_div = 1
+    kernel_ops = total_add + total_div
+    num_elements = y.numel()
+    total_ops = kernel_ops * num_elements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_linear(m, x, y):
+    # per output element
+    total_mul = m.in_features
+    total_add = m.in_features - 1
+    num_elements = y.numel()
+    total_ops = (total_mul + total_add) * num_elements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_LastLevelMaxPool(m, x, y):
+    num_elements = y[-1].numel()
+    total_ops = num_elements
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+def count_ROIAlign(m, x, y):
+    num_elements = y.numel()
+    total_ops = num_elements*4
+    m.total_ops = torch.Tensor([int(total_ops)])
+
+
+register_hooks = {
+    Scale: None,
+    Conv2d: count_conv2d,
+    nn.Conv2d: count_conv2d,
+    ModulatedDeformConv: count_conv2d,
+    StdConv2d: count_conv2d,
+
+    nn.BatchNorm1d: count_bn,
+    nn.BatchNorm2d: count_bn,
+    nn.BatchNorm3d: count_bn,
+    FrozenBatchNorm2d: count_bn,
+    nn.GroupNorm: count_bn,
+    NaiveSyncBatchNorm2d: count_bn,
+
+    nn.ReLU: count_relu,
+    nn.ReLU6: count_relu,
+    swish: None,
+
+    nn.ConstantPad2d: None,
+    SPPLayer: count_LastLevelMaxPool,
+    LastLevelMaxPool: count_LastLevelMaxPool,
+    nn.MaxPool1d: count_maxpool,
+    nn.MaxPool2d: count_maxpool,
+    nn.MaxPool3d: count_maxpool,
+    nn.AdaptiveMaxPool1d: count_adap_maxpool,
+    nn.AdaptiveMaxPool2d: count_adap_maxpool,
+    nn.AdaptiveMaxPool3d: count_adap_maxpool,
+    nn.AvgPool1d: count_avgpool,
+    nn.AvgPool2d: count_avgpool,
+    nn.AvgPool3d: count_avgpool,
+    nn.AdaptiveAvgPool1d: count_adap_avgpool,
+    nn.AdaptiveAvgPool2d: count_adap_avgpool,
+    nn.AdaptiveAvgPool3d: count_adap_avgpool,
+    nn.Linear: count_linear,
+    nn.Upsample: None,
+    nn.Dropout: None,
+    nn.Sigmoid: None,
+    DropBlock2D: None,
+
+    ROIAlign: count_ROIAlign,
+    RPNPostProcessor: None,
+    PostProcessor: None,
+    BufferList: None,
+    RetinaPostProcessor: None,
+    FCOSPostProcessor: None,
+    ATSSPostProcessor: None,
+}
\ No newline at end of file
diff --git a/maskrcnn_benchmark/utils/fuse_helper.py b/maskrcnn_benchmark/utils/fuse_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9ea03f9f69c9d1a4f9a49c90436d540dc612e5
--- /dev/null
+++ b/maskrcnn_benchmark/utils/fuse_helper.py
@@ -0,0 +1,608 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pdb
+import math
+from maskrcnn_benchmark.modeling.utils import cat, concat_box_prediction_layers, permute_and_flatten
+from timm.models.layers import DropPath
+
+from transformers.activations import ACT2FN
+class BertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+class FeatureResizer(nn.Module):
+    """
+    This class takes as input a set of embeddings of dimension C1 and outputs a set of
+    embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
+    """
+
+    def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
+        super().__init__()
+        self.do_ln = do_ln
+        # Object feature encoding
+        self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
+        self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, encoder_features):
+        x = self.fc(encoder_features)
+        if self.do_ln:
+            x = self.layer_norm(x)
+        output = self.dropout(x)
+        return output
+
+
+def _make_conv(input_dim, output_dim, k, stride=1):
+    pad = (k - 1) // 2
+    return nn.Sequential(
+        nn.Conv2d(input_dim, output_dim, (k, k), padding=(pad, pad), stride=(stride, stride)),
+        nn.BatchNorm2d(output_dim),
+        nn.ReLU(inplace=True)
+    )
+
+
+def _make_mlp(input_dim, output_dim, drop):
+    return nn.Sequential(nn.Linear(input_dim, output_dim),
+                         nn.BatchNorm1d(output_dim),
+                         nn.ReLU(inplace=True),
+                         nn.Dropout(drop),
+                         nn.Linear(output_dim, output_dim),
+                         nn.BatchNorm1d(output_dim),
+                         nn.ReLU(inplace=True))
+
+
+def _make_coord(batch, height, width):
+    # relative position encoding
+    xv, yv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)])
+    xv_min = (xv.float() * 2 - width) / width
+    yv_min = (yv.float() * 2 - height) / height
+    xv_max = ((xv + 1).float() * 2 - width) / width
+    yv_max = ((yv + 1).float() * 2 - height) / height
+    xv_ctr = (xv_min + xv_max) / 2
+    yv_ctr = (yv_min + yv_max) / 2
+    hmap = torch.ones(height, width) * (1. / height)
+    wmap = torch.ones(height, width) * (1. / width)
+    coord = torch.autograd.Variable(torch.cat([xv_min.unsqueeze(0), yv_min.unsqueeze(0), \
+                                               xv_max.unsqueeze(0), yv_max.unsqueeze(0), \
+                                               xv_ctr.unsqueeze(0), yv_ctr.unsqueeze(0), \
+                                               hmap.unsqueeze(0), wmap.unsqueeze(0)], dim=0))
+    coord = coord.unsqueeze(0).repeat(batch, 1, 1, 1)
+    return coord
+
+
+def l1norm(X, dim, eps=1e-8):
+    """L1-normalize columns of X
+    """
+    norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
+    X = torch.div(X, norm)
+    return X
+
+
+def l2norm(X, dim, eps=1e-8):
+    """L2-normalize columns of X
+    """
+    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
+    X = torch.div(X, norm)
+    return X
+
+
+def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
+    """
+    query: (n_context, queryL, d)
+    context: (n_context, sourceL, d)
+    """
+    batch_size_q, queryL = query.size(0), query.size(1)
+    batch_size, sourceL = context.size(0), context.size(1)
+
+    # Get attention
+    # --> (batch, d, queryL)
+    queryT = torch.transpose(query, 1, 2)
+
+    # (batch, sourceL, d)(batch, d, queryL)
+    # --> (batch, sourceL, queryL)
+    attn = torch.bmm(context, queryT)
+    if raw_feature_norm == "softmax":
+        # --> (batch*sourceL, queryL)
+        attn = attn.view(batch_size * sourceL, queryL)
+        attn = nn.Softmax()(attn)
+        # --> (batch, sourceL, queryL)
+        attn = attn.view(batch_size, sourceL, queryL)
+    elif raw_feature_norm == "l2norm":
+        attn = l2norm(attn, 2)
+    elif raw_feature_norm == "clipped_l2norm":
+        attn = nn.LeakyReLU(0.1)(attn)
+        attn = l2norm(attn, 2)
+    else:
+        raise ValueError("unknown first norm type:", raw_feature_norm)
+    # --> (batch, queryL, sourceL)
+    attn = torch.transpose(attn, 1, 2).contiguous()
+    # --> (batch*queryL, sourceL)
+    attn = attn.view(batch_size * queryL, sourceL)
+    attn = nn.Softmax()(attn * smooth)
+    # --> (batch, queryL, sourceL)
+    attn = attn.view(batch_size, queryL, sourceL)
+    # --> (batch, sourceL, queryL)
+    attnT = torch.transpose(attn, 1, 2).contiguous()
+
+    # --> (batch, d, sourceL)
+    contextT = torch.transpose(context, 1, 2)
+    # (batch x d x sourceL)(batch x sourceL x queryL)
+    # --> (batch, d, queryL)
+    weightedContext = torch.bmm(contextT, attnT)
+    # --> (batch, queryL, d)
+    weightedContext = torch.transpose(weightedContext, 1, 2)
+
+    return weightedContext, attnT
+
+
+class BiMultiHeadAttention(nn.Module):
+    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
+        super(BiMultiHeadAttention, self).__init__()
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.v_dim = v_dim
+        self.l_dim = l_dim
+
+        assert (
+                self.head_dim * self.num_heads == self.embed_dim
+        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+        self.scale = self.head_dim ** (-0.5)
+        self.dropout = dropout
+
+        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
+        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
+        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
+        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
+
+        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
+        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
+
+        self.stable_softmax_2d = cfg.MODEL.DYHEAD.FUSE_CONFIG.STABLE_SOFTMAX_2D
+        self.clamp_min_for_underflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW
+        self.clamp_max_for_overflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW
+
+        self._reset_parameters()
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def _reset_parameters(self):
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        self.v_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.l_proj.weight)
+        self.l_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.values_v_proj.weight)
+        self.values_v_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.values_l_proj.weight)
+        self.values_l_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.out_v_proj.weight)
+        self.out_v_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.out_l_proj.weight)
+        self.out_l_proj.bias.data.fill_(0)
+
+    def forward(self, v, l, attention_mask_l=None):
+        bsz, tgt_len, embed_dim = v.size()
+
+        query_states = self.v_proj(v) * self.scale
+        key_states = self._shape(self.l_proj(l), -1, bsz)
+        value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
+        value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_v_states = value_v_states.view(*proj_shape)
+        value_l_states = value_l_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
+
+        # attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1)
+
+        if self.stable_softmax_2d:
+            attn_weights = attn_weights - attn_weights.max()
+        
+        if self.clamp_min_for_underflow:
+            attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
+        if self.clamp_max_for_overflow:
+            attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range
+
+        attn_weights_T = attn_weights.transpose(1, 2)
+        attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[
+            0])
+        if self.clamp_min_for_underflow:
+            attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range
+        if self.clamp_max_for_overflow:
+            attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range
+
+        attn_weights_l = attn_weights_l.softmax(dim=-1)
+
+        if attention_mask_l is not None:
+            assert (attention_mask_l.dim() == 2)
+            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)
+            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
+            attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)
+
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
+        attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
+
+        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
+        attn_output_l = torch.bmm(attn_probs_l, value_v_states)
+
+
+        if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
+            )
+
+        if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
+            )
+
+        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output_v = attn_output_v.transpose(1, 2)
+        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
+        attn_output_l = attn_output_l.transpose(1, 2)
+        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
+
+        attn_output_v = self.out_v_proj(attn_output_v)
+        attn_output_l = self.out_l_proj(attn_output_l)
+
+        return attn_output_v, attn_output_l
+
+
+# Bi-Direction MHA (text->image, image->text)
+class BiAttentionBlock(nn.Module):
+    def __init__(self, v_dim, l_dim, embed_dim, num_heads, hidden_dim=None, dropout=0.1,
+                 drop_path=.0, init_values=1e-4, cfg=None):
+        """
+        Inputs:
+            embed_dim - Dimensionality of input and attention feature vectors
+            hidden_dim - Dimensionality of hidden layer in feed-forward network
+                         (usually 2-4x larger than embed_dim)
+            num_heads - Number of heads to use in the Multi-Head Attention block
+            dropout - Amount of dropout to apply in the feed-forward network
+        """
+        super(BiAttentionBlock, self).__init__()
+
+        # pre layer norm
+        self.layer_norm_v = nn.LayerNorm(v_dim)
+        self.layer_norm_l = nn.LayerNorm(l_dim)
+        self.attn = BiMultiHeadAttention(v_dim=v_dim,
+                                         l_dim=l_dim,
+                                         embed_dim=embed_dim,
+                                         num_heads=num_heads,
+                                         dropout=dropout,
+                                         cfg=cfg)
+
+        # add layer scale for training stability
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
+        self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
+
+    def forward(self, v, l, attention_mask_l=None, dummy_tensor=None):
+        v = self.layer_norm_v(v)
+        l = self.layer_norm_l(l)
+        delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l)
+        # v, l = v + delta_v, l + delta_l
+        v = v + self.drop_path(self.gamma_v * delta_v)
+        l = l + self.drop_path(self.gamma_l * delta_l)
+        return v, l
+
+class BiAttentionBlockForCheckpoint(nn.Module):
+    def __init__(self, v_dim, l_dim, embed_dim, num_heads, hidden_dim=None, dropout=0.1,
+                 drop_path=.0, init_values=1e-4, cfg=None):
+        """
+        Inputs:
+            embed_dim - Dimensionality of input and attention feature vectors
+            hidden_dim - Dimensionality of hidden layer in feed-forward network
+                         (usually 2-4x larger than embed_dim)
+            num_heads - Number of heads to use in the Multi-Head Attention block
+            dropout - Amount of dropout to apply in the feed-forward network
+        """
+        super(BiAttentionBlockForCheckpoint, self).__init__()
+
+        # pre layer norm
+        self.layer_norm_v = nn.LayerNorm(v_dim)
+        self.layer_norm_l = nn.LayerNorm(l_dim)
+        self.attn = BiMultiHeadAttention(v_dim=v_dim,
+                                         l_dim=l_dim,
+                                         embed_dim=embed_dim,
+                                         num_heads=num_heads,
+                                         dropout=dropout,
+                                         cfg=cfg)
+
+        # add layer scale for training stability
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
+        self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
+
+        self.cfg = cfg
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL:
+            if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT:
+                self.shrink_lang = FeatureResizer(l_dim * 5, l_dim, 0.1)
+
+    def forward(self, q0, q1, q2, q3, q4, l, attention_mask_l=None, dummy_tensor=None):
+
+        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL:
+            visu_feat = []
+            lang_feat = []
+            for ii, feat in enumerate([q0, q1, q2, q3, q4]):
+                bs, _, h, w = feat.shape
+                q = feat.flatten(2).transpose(1, 2)
+                
+                new_v, new_l = self.single_attention_call(q, l, attention_mask_l=attention_mask_l)
+                new_v = new_v.transpose(1, 2).contiguous().view(bs, -1, h, w)
+                lang_feat.append(new_l)
+                visu_feat.append(new_v)
+            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT:
+                pass
+            else:
+                lang_feat = self.shrink_lang(torch.cat(lang_feat, dim = -1)) # From multiple dimensions
+                lang_feat = [lang_feat, None, None, None, None]
+        else:
+            visu_feat = []
+            size_per_level, visual_features_flatten = [], []
+            for ii, feat_per_level in enumerate([q0, q1, q2, q3, q4]):
+                bs, c, h, w = feat_per_level.shape
+                size_per_level.append([h, w])
+                feat = permute_and_flatten(feat_per_level, bs, 1, c, h, w)
+                visual_features_flatten.append(feat)
+            visual_features_flatten = cat(visual_features_flatten, dim=1)
+            new_v, new_l = self.single_attention_call(visual_features_flatten, l, attention_mask_l=attention_mask_l)
+            # [bs, N, C] -> [bs, C, N]
+            new_v = new_v.transpose(1, 2).contiguous()
+
+            start = 0
+            for (h, w) in size_per_level:
+                new_v_per_level = new_v[:, :, start:start + h * w].view(bs, -1, h, w).contiguous()
+                visu_feat.append(new_v_per_level)
+                start += h * w
+            
+            lang_feat = [new_l, None, None, None, None]
+
+        return visu_feat[0], visu_feat[1], visu_feat[2], visu_feat[3], visu_feat[4], lang_feat[0], lang_feat[1], lang_feat[2], lang_feat[3], lang_feat[4]
+
+    
+    def single_attention_call(self, v, l, attention_mask_l=None, dummy_tensor=None):
+        v = self.layer_norm_v(v)
+        l = self.layer_norm_l(l)
+        delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l)
+        # v, l = v + delta_v, l + delta_l
+        v = v + self.drop_path(self.gamma_v * delta_v)
+        l = l + self.drop_path(self.gamma_l * delta_l)
+        return v, l
+
+
+# Single Direction MHA
+class MultiHeadAttention(nn.Module):
+    """
+    Multi-head attention module for both image and text
+    """
+
+    def __init__(self, q_dim, k_dim, embed_dim, num_heads, dropout=0.1, 
+        clamp_min_for_underflow = False, clamp_max_for_overflow = False):
+        super(MultiHeadAttention, self).__init__()
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.q_dim = q_dim
+        self.k_dim = k_dim
+
+        assert (
+                self.head_dim * self.num_heads == self.embed_dim
+        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+        self.scale = self.head_dim ** (-0.5)
+        self.dropout = dropout
+
+        self.q_proj = nn.Linear(self.q_dim, self.embed_dim)
+        self.k_proj = nn.Linear(self.k_dim, self.embed_dim)
+        self.v_proj = nn.Linear(self.k_dim, self.embed_dim)
+        self.out_proj = nn.Linear(self.embed_dim, self.q_dim)
+        self.clamp_min_for_underflow = clamp_min_for_underflow
+        self.clamp_max_for_overflow = clamp_max_for_overflow
+
+        self._reset_parameters()
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def _reset_parameters(self):
+        nn.init.xavier_uniform_(self.q_proj.weight)
+        self.q_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.k_proj.weight)
+        self.k_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        self.v_proj.bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        self.out_proj.bias.data.fill_(0)
+
+    def forward(self, q, k, v, attention_mask=None, return_attention=False):
+        bsz, tgt_len, embed_dim = q.size()
+
+        query_states = self.q_proj(q) * self.scale
+        key_states = self._shape(self.k_proj(k), -1, bsz)
+        value_states = self._shape(self.v_proj(v), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
+
+        if self.clamp_min_for_underflow:
+            attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
+        if self.clamp_max_for_overflow:
+            attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range
+
+        if attention_mask is not None:
+            # [bsz, src_len]
+            assert (attention_mask.dim() == 2)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
+            attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)
+
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if return_attention:
+            # this operation is a bit akward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+
+        return attn_output, attn_weights
+
+
+class AttentionMLP(nn.Module):
+    def __init__(self, q_dim, hidden_dim, dropout=0.1):
+        super(AttentionMLP, self).__init__()
+        self.hidden_dim = hidden_dim
+        self.activation_fn = nn.GELU()
+        self.fc1 = nn.Linear(q_dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, q_dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class AttentionT2I(nn.Module):
+    def __init__(self, q_dim, k_dim, embed_dim, num_heads, hidden_dim=None, dropout=0.1,
+                 drop_path=.0, init_values=1e-4, mode="i2t", use_layer_scale = False,
+                 clamp_min_for_underflow = False, clamp_max_for_overflow = False):
+        """
+        Inputs:
+            embed_dim - Dimensionality of input and attention feature vectors
+            hidden_dim - Dimensionality of hidden layer in feed-forward network
+                         (usually 2-4x larger than embed_dim)
+            num_heads - Number of heads to use in the Multi-Head Attention block
+            dropout - Amount of dropout to apply in the feed-forward network
+        """
+        super(AttentionT2I, self).__init__()
+
+        # pre_layer norm
+        self.layer_norm_q_1 = nn.LayerNorm(q_dim)
+        self.layer_norm_k_1 = nn.LayerNorm(k_dim)
+        self.attn = MultiHeadAttention(q_dim=q_dim,
+                                       k_dim=k_dim,
+                                       embed_dim=embed_dim,
+                                       num_heads=num_heads,
+                                       clamp_min_for_underflow=clamp_min_for_underflow,
+                                       clamp_max_for_overflow=clamp_max_for_overflow)
+        self.mode = mode
+
+        # add layer scale for training stability
+        self.use_layer_scale = use_layer_scale
+        if self.use_layer_scale:
+            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+            self.gamma = nn.Parameter(init_values * torch.ones((q_dim)), requires_grad=True)
+
+
+    def forward(self, q0, q1, q2, q3, q4, k, v, attention_mask, dummy_arg=None):
+        qs = []
+        for q_index, q in enumerate([q0, q1, q2, q3, q4]):
+            bs, _, h, w = q.shape
+            # (batch, seq_len, embed_size)
+            q = q.flatten(2).transpose(1, 2)
+            q = self.layer_norm_q_1(q)
+            k, v = self.layer_norm_k_1(k), self.layer_norm_k_1(v)
+            delta_q = self.attn(q, k, v, attention_mask=attention_mask)[0]
+            if self.use_layer_scale:
+                q = q + self.drop_path(self.gamma * delta_q)
+            else:
+                q = q + delta_q
+            q = q.transpose(1, 2).contiguous().view(bs, -1, h, w)
+            qs.append(q)
+
+
+        return qs[0], qs[1], qs[2], qs[3], qs[4]
diff --git a/maskrcnn_benchmark/utils/imports.py b/maskrcnn_benchmark/utils/imports.py
new file mode 100644
index 0000000000000000000000000000000000000000..081e5556f74f0068957f4514593ca7446652d546
--- /dev/null
+++ b/maskrcnn_benchmark/utils/imports.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+if torch._six.PY37:
+    import importlib
+    import importlib.util
+    import sys
+
+
+    # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
+    def import_file(module_name, file_path, make_importable=False):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)
+        if make_importable:
+            sys.modules[module_name] = module
+        return module
+else:
+    import imp
+
+    def import_file(module_name, file_path, make_importable=None):
+        module = imp.load_source(module_name, file_path)
+        return module
diff --git a/maskrcnn_benchmark/utils/logger.py b/maskrcnn_benchmark/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30fa8d49a67111cb7d8d47e7db1ece98134aa8e
--- /dev/null
+++ b/maskrcnn_benchmark/utils/logger.py
@@ -0,0 +1,25 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import os
+import sys
+
+
+def setup_logger(name, save_dir, distributed_rank):
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    # don't log results for the non-master process
+    if distributed_rank > 0:
+        return logger
+    ch = logging.StreamHandler(stream=sys.stdout)
+    ch.setLevel(logging.DEBUG)
+    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+
+    if save_dir:
+        fh = logging.FileHandler(os.path.join(save_dir, "log.txt"))
+        fh.setLevel(logging.DEBUG)
+        fh.setFormatter(formatter)
+        logger.addHandler(fh)
+
+    return logger
diff --git a/maskrcnn_benchmark/utils/mdetr_dist.py b/maskrcnn_benchmark/utils/mdetr_dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..af8f19fd511db7b871e78abf0e64d1225994406d
--- /dev/null
+++ b/maskrcnn_benchmark/utils/mdetr_dist.py
@@ -0,0 +1,229 @@
+# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Utilities related to distributed mode.
+
+By default, the reduce of metrics and such are done on GPU, since it's more straightforward (we reuse the NCCL backend)
+If you want to reduce on CPU instead (required for big datasets like GQA), use the env variable MDETR_CPU_REDUCE=1
+"""
+import functools
+import io
+import os
+import datetime
+
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks
+    The result is cached.
+    """
+
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+
+    return dist.group.WORLD
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    cpu_group = None
+    if os.getenv("MDETR_CPU_REDUCE") == "1":
+        cpu_group = _get_global_gloo_group()
+
+    buffer = io.BytesIO()
+    torch.save(data, buffer)
+    data_view = buffer.getbuffer()
+    device = "cuda" if cpu_group is None else "cpu"
+    tensor = torch.ByteTensor(data_view).to(device)
+
+    # obtain Tensor size of each rank
+    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
+    size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
+    if cpu_group is None:
+        dist.all_gather(size_list, local_size)
+    else:
+        print("gathering on cpu")
+        dist.all_gather(size_list, local_size, group=cpu_group)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+    assert isinstance(local_size.item(), int)
+    local_size = int(local_size.item())
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
+    if local_size != max_size:
+        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    if cpu_group is None:
+        dist.all_gather(tensor_list, tensor)
+    else:
+        dist.all_gather(tensor_list, tensor, group=cpu_group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
+        buffer = io.BytesIO(tensor.cpu().numpy())
+        obj = torch.load(buffer)
+        data_list.append(obj)
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that all processes
+    have the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.all_reduce(values)
+        if average:
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    """
+    Returns:
+        True if distributed training is enabled
+    """
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    """
+    Returns:
+        The number of processes in the process group
+    """
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    """
+    Returns:
+        The rank of the current process within the global process group.
+    """
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    """
+    Returns:
+        The rank of the current process within the local (per-machine) process group.
+    """
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    assert _LOCAL_PROCESS_GROUP is not None
+    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+    """
+    Returns:
+        The size of the per-machine process group,
+        i.e. the number of processes per machine.
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process():
+    """Return true if the current process is the main one"""
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    """Utility function to save only from the main process"""
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    """Initialize distributed training, if appropriate"""
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
+
+    dist.init_process_group(
+        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank,
+        timeout=datetime.timedelta(0, 7200)
+    )
+    dist.barrier()
+    setup_for_distributed(args.debug or args.rank == 0)
diff --git a/maskrcnn_benchmark/utils/metric_logger.py b/maskrcnn_benchmark/utils/metric_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..b506f40f1d7e912389c9c27dafd4d340552a6a9c
--- /dev/null
+++ b/maskrcnn_benchmark/utils/metric_logger.py
@@ -0,0 +1,130 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from collections import defaultdict
+from collections import deque
+
+import torch
+import time
+from datetime import datetime
+from .comm import is_main_process
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20):
+        self.deque = deque(maxlen=window_size)
+        # self.series = []
+        self.total = 0.0
+        self.count = 0
+
+    def update(self, value):
+        self.deque.append(value)
+        # self.series.append(value)
+        self.count += 1
+        if value != value:
+            value = 0
+        self.total += value
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque))
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
+            )
+        return self.delimiter.join(loss_str)
+
+
+# haotian added tensorboard support
+class TensorboardLogger(MetricLogger):
+    def __init__(self,
+                 log_dir,
+                 start_iter=0,
+                 delimiter='\t'
+                 ):
+        super(TensorboardLogger, self).__init__(delimiter)
+        self.iteration = start_iter
+        self.writer = self._get_tensorboard_writer(log_dir)
+
+    @staticmethod
+    def _get_tensorboard_writer(log_dir):
+        try:
+            from tensorboardX import SummaryWriter
+        except ImportError:
+            raise ImportError(
+                'To use tensorboard please install tensorboardX '
+                '[ pip install tensorflow tensorboardX ].'
+            )
+
+        if is_main_process():
+            # timestamp = datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H:%M')
+            tb_logger = SummaryWriter('{}'.format(log_dir))
+            return tb_logger
+        else:
+            return None
+
+    def update(self, **kwargs):
+        super(TensorboardLogger, self).update(**kwargs)
+        if self.writer:
+            for k, v in kwargs.items():
+                if isinstance(v, torch.Tensor):
+                    v = v.item()
+                assert isinstance(v, (float, int))
+                self.writer.add_scalar(k, v, self.iteration)
+
+            self.iteration += 1
diff --git a/maskrcnn_benchmark/utils/miscellaneous.py b/maskrcnn_benchmark/utils/miscellaneous.py
new file mode 100644
index 0000000000000000000000000000000000000000..0169648926c729c217520442cd59a9214975a3bb
--- /dev/null
+++ b/maskrcnn_benchmark/utils/miscellaneous.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import errno
+import os
+from .comm import is_main_process
+
+def mkdir(path):
+    try:
+        os.makedirs(path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def save_config(cfg, path):
+    if is_main_process():
+        with open(path, 'w') as f:
+            f.write(cfg.dump())
diff --git a/maskrcnn_benchmark/utils/model_serialization.py b/maskrcnn_benchmark/utils/model_serialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8707ceb92f045d78e756c0a00df7e0192c39f1e
--- /dev/null
+++ b/maskrcnn_benchmark/utils/model_serialization.py
@@ -0,0 +1,157 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from collections import OrderedDict, defaultdict
+import logging
+import math
+import torch
+
+from maskrcnn_benchmark.utils.imports import import_file
+
+def resize_2d(posemb, shape_new):
+    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    ntok_new = shape_new[0]
+    gs_old = int(math.sqrt(len(posemb)))  # 2 * w - 1
+    gs_new = int(math.sqrt(ntok_new))  # 2 * w - 1
+    posemb_grid = posemb.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(gs_new * gs_new, -1)
+    return posemb_grid
+
+def align_and_update_state_dicts(model_state_dict, loaded_state_dict, reshape_keys=['pos_bias_table'], use_weightmap=False):
+    """
+    Strategy: suppose that the models that we will create will have prefixes appended
+    to each of its keys, for example due to an extra level of nesting that the original
+    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
+    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+    res2.conv1.weight. We thus want to match both parameters together.
+    For that, we look for each model weight, look among all loaded keys if there is one
+    that is a suffix of the current weight name, and use it if that's the case.
+    If multiple matches exist, take the one with longest size
+    of the corresponding name. For example, for the same model as before, the pretrained
+    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
+    we want to match backbone[0].body.conv1.weight to conv1.weight, and
+    backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+    """
+    current_keys = sorted(list(model_state_dict.keys()))
+    loaded_keys = sorted(list(loaded_state_dict.keys()))
+    # get a matrix of string matches, where each (i, j) entry correspond to the size of the
+    # loaded_key string, if it matches
+    match_matrix = [
+        len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
+    ]
+    match_matrix = torch.as_tensor(match_matrix).view(
+        len(current_keys), len(loaded_keys)
+    )
+    max_match_size, idxs = match_matrix.max(1)
+    # remove indices that correspond to no-match
+    idxs[max_match_size == 0] = -1
+
+    matched_keys = []
+    # used for logging
+    max_size = max([len(key) for key in current_keys]) if current_keys else 1
+    max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
+    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
+    logger = logging.getLogger(__name__)
+    for idx_new, idx_old in enumerate(idxs.tolist()):
+        if idx_old == -1:
+            continue
+        key = current_keys[idx_new]
+        key_old = loaded_keys[idx_old]
+        if model_state_dict[key].shape != loaded_state_dict[key_old].shape:
+            if any([k in key_old for k in reshape_keys]):
+                new_shape = model_state_dict[key].shape
+                logger.warning('Reshaping {} -> {}. \n'.format(key_old, key))
+                model_state_dict[key] = resize_2d(loaded_state_dict[key_old], new_shape)
+            elif use_weightmap and 'cls_logits' in key:
+                coco_in_objects365_inds = [
+                    227, 26, 55, 202, 2, 44, 338, 346, 32, 336, 118, 299, 218,
+                    25, 361, 59, 95, 161, 278, 82, 110, 22, 364, 134, 9, 350,
+                    152, 323, 304, 130, 285, 289, 16, 172, 17, 18, 283, 305,
+                    321, 35, 362, 88, 127, 174, 292, 37, 11, 6, 267, 212, 41,
+                    58, 162, 237, 98, 48, 63, 81, 247, 23, 94, 326, 349, 178,
+                    203, 259, 171, 60, 198, 213, 325, 282, 258, 33, 71, 353,
+                    273, 318, 148, 330
+                ]
+                logger.info("Use coco_in_objects365_inds labelmap for COCO detection because of size mis-match, "
+                      "Reshaping {} -> {}. \n".format(key_old, key))
+                new_shape = model_state_dict[key].shape
+                assert new_shape[0] == len(coco_in_objects365_inds)
+                weight_inds_old = torch.as_tensor(coco_in_objects365_inds).to(loaded_state_dict[key_old].device)
+                model_state_dict[key] = loaded_state_dict[key_old][weight_inds_old].to(model_state_dict[key].device)
+            else:
+                logger.info('Skip due to size mismatch: {} -> {}. \n'.format(key_old, key))
+                continue
+        else:
+            model_state_dict[key] = loaded_state_dict[key_old]
+        matched_keys.append(key)
+        logger.info(
+            log_str_template.format(
+                key,
+                max_size,
+                key_old,
+                max_size_loaded,
+                tuple(loaded_state_dict[key_old].shape),
+            )
+        )
+    missing_keys = set(current_keys)-set(matched_keys)
+    if len(missing_keys):
+        groups = _group_checkpoint_keys(missing_keys)
+        msg_per_group = sorted(k + _group_to_str(v) for k, v in groups.items())
+        msg = '\n'.join(sorted(msg_per_group))
+        logger.warning('Some layers unloaded with pre-trained weight: \n' + msg)
+
+def strip_prefix_if_present(state_dict, prefix):
+    keys = sorted(state_dict.keys())
+    if not all(key.startswith(prefix) for key in keys):
+        return state_dict
+    stripped_state_dict = OrderedDict()
+    for key, value in state_dict.items():
+        stripped_state_dict[key.replace(prefix, "", 1)] = value
+    return stripped_state_dict
+
+def load_state_dict(model, loaded_state_dict):
+    model_state_dict = model.state_dict()
+    # if the state_dict comes from a model that was wrapped in a
+    # DataParallel or DistributedDataParallel during serialization,
+    # remove the "module" prefix before performing the matching
+    loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
+    align_and_update_state_dicts(model_state_dict, loaded_state_dict)
+
+    # use strict loading
+    model.load_state_dict(model_state_dict)
+
+def _group_checkpoint_keys(keys):
+    """
+    Group keys based on common prefixes. A prefix is the string up to the final
+    "." in each key.
+    Args:
+        keys (list[str]): list of parameter names, i.e. keys in the model
+            checkpoint dict.
+    Returns:
+        dict[list]: keys with common prefixes are grouped into lists.
+    """
+    groups = defaultdict(list)
+    for key in keys:
+        pos = key.rfind(".")
+        if pos >= 0:
+            head, tail = key[:pos], [key[pos + 1 :]]
+        else:
+            head, tail = key, []
+        groups[head].extend(tail)
+    return groups
+
+def _group_to_str(group):
+    """
+    Format a group of parameter name suffixes into a loggable string.
+    Args:
+        group (list[str]): list of parameter name suffixes.
+    Returns:
+        str: formated string.
+    """
+    if len(group) == 0:
+        return ""
+
+    if len(group) == 1:
+        return "." + group[0]
+
+    return ".{" + ", ".join(sorted(group)) + "}"
\ No newline at end of file
diff --git a/maskrcnn_benchmark/utils/model_zoo.py b/maskrcnn_benchmark/utils/model_zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..96aef6fda6cec21c074c33fd1e3934cf52088e60
--- /dev/null
+++ b/maskrcnn_benchmark/utils/model_zoo.py
@@ -0,0 +1,61 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import os
+import sys
+
+try:
+    from torch.hub import _download_url_to_file
+    from torch.hub import urlparse
+    from torch.hub import HASH_REGEX
+except ImportError:
+    from torch.utils.model_zoo import _download_url_to_file
+    from torch.utils.model_zoo import urlparse
+    from torch.utils.model_zoo import HASH_REGEX
+
+from maskrcnn_benchmark.utils.comm import is_main_process
+from maskrcnn_benchmark.utils.comm import synchronize
+
+
+# very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py
+# but with a few improvements and modifications
+def cache_url(url, model_dir='model', progress=True):
+    r"""Loads the Torch serialized object at the given URL.
+    If the object is already present in `model_dir`, it's deserialized and
+    returned. The filename part of the URL should follow the naming convention
+    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
+    digits of the SHA256 hash of the contents of the file. The hash is used to
+    ensure unique names and to verify the contents of the file.
+    The default value of `model_dir` is ``$TORCH_HOME/models`` where
+    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
+    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
+    Args:
+        url (string): URL of the object to download
+        model_dir (string, optional): directory in which to save the object
+        progress (bool, optional): whether or not to display a progress bar to stderr
+    Example:
+        >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
+    """
+    if model_dir is None:
+        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
+        model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir, exist_ok=True)
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if filename == "model_final.pkl":
+        # workaround as pre-trained Caffe2 models from Detectron have all the same filename
+        # so make the full path the filename by replacing / with _
+        filename = parts.path.replace("/", "_")
+    cached_file = os.path.join(model_dir, filename)
+    if not os.path.exists(cached_file):
+        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+        hash_prefix = HASH_REGEX.search(filename)
+        if hash_prefix is not None:
+            hash_prefix = hash_prefix.group(1)
+            # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
+            # which matches the hash PyTorch uses. So we skip the hash matching
+            # if the hash_prefix is less than 6 characters
+            if len(hash_prefix) < 6:
+                hash_prefix = None
+        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+    synchronize()
+    return cached_file
diff --git a/maskrcnn_benchmark/utils/pretrain_model_loading.py b/maskrcnn_benchmark/utils/pretrain_model_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45f05a5e3baa225a9e628ddcc3bfc0a0eefc238
--- /dev/null
+++ b/maskrcnn_benchmark/utils/pretrain_model_loading.py
@@ -0,0 +1,49 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from collections import OrderedDict
+
+def _remove_bn_statics(state_dict):
+    layer_keys = sorted(state_dict.keys())
+    remove_list = []
+    for key in layer_keys:
+        if 'running_mean' in key or 'running_var' in key or 'num_batches_tracked' in key:
+            remove_list.append(key)
+    for key in remove_list:
+        del state_dict[key]
+    return state_dict
+
+def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
+    import re
+    layer_keys = sorted(state_dict.keys())
+    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
+        if not stage_with_dcn:
+            continue
+        for old_key in layer_keys:
+            pattern = ".*layer{}.*conv2.*".format(ix)
+            r = re.match(pattern, old_key)
+            if r is None:
+                continue
+            for param in ["weight", "bias"]:
+                if old_key.find(param) is -1:
+                    continue
+                if 'unit01' in old_key:
+                    continue
+                new_key = old_key.replace(
+                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
+                )
+                print("pattern: {}, old_key: {}, new_key: {}".format(
+                    pattern, old_key, new_key
+                ))
+                state_dict[new_key] = state_dict[old_key]
+                del state_dict[old_key]
+    return state_dict
+
+
+def load_pretrain_format(cfg, f):
+    model = torch.load(f)
+    model = _remove_bn_statics(model)
+    model = _rename_conv_weights_for_deformable_conv_layers(model, cfg)
+
+    return dict(model=model)
diff --git a/maskrcnn_benchmark/utils/registry.py b/maskrcnn_benchmark/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae82dfb879f19ca0c3d9056abdb440b0863cb912
--- /dev/null
+++ b/maskrcnn_benchmark/utils/registry.py
@@ -0,0 +1,45 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+
+def _register_generic(module_dict, module_name, module):
+    assert module_name not in module_dict
+    module_dict[module_name] = module
+
+
+class Registry(dict):
+    '''
+    A helper class for managing registering modules, it extends a dictionary
+    and provides a register functions.
+
+    Eg. creeting a registry:
+        some_registry = Registry({"default": default_module})
+
+    There're two ways of registering new modules:
+    1): normal way is just calling register function:
+        def foo():
+            ...
+        some_registry.register("foo_module", foo)
+    2): used as decorator when declaring the module:
+        @some_registry.register("foo_module")
+        @some_registry.register("foo_modeul_nickname")
+        def foo():
+            ...
+
+    Access of module is just like using a dictionary, eg:
+        f = some_registry["foo_modeul"]
+    '''
+    def __init__(self, *args, **kwargs):
+        super(Registry, self).__init__(*args, **kwargs)
+
+    def register(self, module_name, module=None):
+        # used as function call
+        if module is not None:
+            _register_generic(self, module_name, module)
+            return
+
+        # used as decorator
+        def register_fn(fn):
+            _register_generic(self, module_name, fn)
+            return fn
+
+        return register_fn
diff --git a/maskrcnn_benchmark/utils/shallow_contrastive_loss_helper.py b/maskrcnn_benchmark/utils/shallow_contrastive_loss_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..027fb4598529c0072f670a4776f2c825968f5caf
--- /dev/null
+++ b/maskrcnn_benchmark/utils/shallow_contrastive_loss_helper.py
@@ -0,0 +1,58 @@
+import torch
+import maskrcnn_benchmark.utils.dist as dist
+
+
+def normalized_positive_map(positive_map):
+    positive_map = positive_map.float()
+    positive_map_num_pos = positive_map.sum(2)
+    positive_map_num_pos[positive_map_num_pos == 0] = 1e-6
+    positive_map = positive_map / positive_map_num_pos.unsqueeze(-1)
+    return positive_map
+
+
+def pad_tensor_given_dim_length(tensor, dim, length, padding_value=0, batch_first=True):
+    new_size = list(tensor.size()[:dim]) + [length] + list(tensor.size()[dim + 1:])
+    out_tensor = tensor.data.new(*new_size).fill_(padding_value)
+    if batch_first:
+        out_tensor[:, :tensor.size(1), ...] = tensor
+    else:
+        out_tensor[:tensor.size(0), ...] = tensor
+    return out_tensor
+
+
+def pad_random_negative_tensor_given_length(positive_tensor, negative_padding_tensor, length=None):
+    assert positive_tensor.shape[0] + negative_padding_tensor.shape[0] == length
+    return torch.cat((positive_tensor, negative_padding_tensor), dim=0)
+
+
+def gather_tensors(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    if not dist.is_dist_avail_and_initialized():
+        return torch.stack([tensor], dim=0)
+
+    total = dist.get_world_size()
+    rank = torch.distributed.get_rank()
+    # gathered_normalized_img_emb = [torch.zeros_like(normalized_img_emb) for _ in range(total)]
+    # torch.distributed.all_gather(gathered_normalized_img_emb, normalized_img_emb)
+
+    tensors_gather = [
+        torch.zeros_like(tensor)
+        for _ in range(total)
+    ]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    # need to do this to restore propagation of the gradients
+    tensors_gather[rank] = tensor
+    output = torch.stack(tensors_gather, dim=0)
+    return output
+
+
+def convert_to_roi_format(boxes):
+    concat_boxes = boxes.bbox
+    device, dtype = concat_boxes.device, concat_boxes.dtype
+    ids = torch.full((len(boxes), 1), 0, dtype=dtype, device=device)
+    rois = torch.cat([ids, concat_boxes], dim=1)
+    return rois
\ No newline at end of file
diff --git a/maskrcnn_benchmark/utils/stats.py b/maskrcnn_benchmark/utils/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae04f1e20b44d2774f73b698ea91a00e7b0ce690
--- /dev/null
+++ b/maskrcnn_benchmark/utils/stats.py
@@ -0,0 +1,510 @@
+'''
+Copyright (C) 2019 Sovrasov V. - All Rights Reserved
+ * You may use, distribute and modify this code under the
+ * terms of the MIT license.
+ * You should have received a copy of the MIT license with
+ * this file. If not visit https://opensource.org/licenses/MIT
+'''
+
+import sys
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from maskrcnn_benchmark.layers import *
+
+def get_model_complexity_info(model, input_res,
+                              print_per_layer_stat=True,
+                              as_strings=True,
+                              input_constructor=None, ost=sys.stdout,
+                              verbose=False, ignore_modules=[],
+                              custom_modules_hooks={}):
+    assert type(input_res) is tuple
+    assert len(input_res) >= 1
+    assert isinstance(model, nn.Module)
+    global CUSTOM_MODULES_MAPPING
+    CUSTOM_MODULES_MAPPING = custom_modules_hooks
+    flops_model = add_flops_counting_methods(model)
+    flops_model.eval()
+    flops_model.start_flops_count(ost=ost, verbose=verbose,
+                                  ignore_list=ignore_modules)
+    if input_constructor:
+        input = input_constructor(input_res)
+        _ = flops_model(**input)
+    else:
+        try:
+            batch = torch.ones(()).new_empty((1, *input_res),
+                                             dtype=next(flops_model.parameters()).dtype,
+                                             device=next(flops_model.parameters()).device)
+        except StopIteration:
+            batch = torch.ones(()).new_empty((1, *input_res))
+
+        _ = flops_model(batch)
+
+    flops_count, params_count = flops_model.compute_average_flops_cost()
+    if print_per_layer_stat:
+        print_model_with_flops(flops_model, flops_count, params_count, ost=ost)
+    flops_model.stop_flops_count()
+    CUSTOM_MODULES_MAPPING = {}
+
+    if as_strings:
+        return flops_to_string(flops_count), params_to_string(params_count)
+
+    return flops_count, params_count
+
+
+def flops_to_string(flops, units='GMac', precision=2):
+    if units is None:
+        if flops // 10**9 > 0:
+            return str(round(flops / 10.**9, precision)) + ' GMac'
+        elif flops // 10**6 > 0:
+            return str(round(flops / 10.**6, precision)) + ' MMac'
+        elif flops // 10**3 > 0:
+            return str(round(flops / 10.**3, precision)) + ' KMac'
+        else:
+            return str(flops) + ' Mac'
+    else:
+        if units == 'GMac':
+            return str(round(flops / 10.**9, precision)) + ' ' + units
+        elif units == 'MMac':
+            return str(round(flops / 10.**6, precision)) + ' ' + units
+        elif units == 'KMac':
+            return str(round(flops / 10.**3, precision)) + ' ' + units
+        else:
+            return str(flops) + ' Mac'
+
+
+def params_to_string(params_num, units=None, precision=2):
+    if units is None:
+        if params_num // 10 ** 6 > 0:
+            return str(round(params_num / 10 ** 6, 2)) + ' M'
+        elif params_num // 10 ** 3:
+            return str(round(params_num / 10 ** 3, 2)) + ' k'
+        else:
+            return str(params_num)
+    else:
+        if units == 'M':
+            return str(round(params_num / 10.**6, precision)) + ' ' + units
+        elif units == 'K':
+            return str(round(params_num / 10.**3, precision)) + ' ' + units
+        else:
+            return str(params_num)
+
+
+def accumulate_flops(self):
+    if is_supported_instance(self):
+        return self.__flops__
+    else:
+        sum = 0
+        for m in self.children():
+            sum += m.accumulate_flops()
+        return sum
+
+
+def print_model_with_flops(model, total_flops, total_params, units='GMac',
+                           precision=3, ost=sys.stdout):
+
+    def accumulate_params(self):
+        if is_supported_instance(self):
+            return self.__params__
+        else:
+            sum = 0
+            for m in self.children():
+                sum += m.accumulate_params()
+            return sum
+
+    def flops_repr(self):
+        accumulated_params_num = self.accumulate_params()
+        accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__
+        return ', '.join([params_to_string(accumulated_params_num,
+                                           units='M', precision=precision),
+                          '{:.3%} Params'.format(accumulated_params_num / total_params),
+                          flops_to_string(accumulated_flops_cost,
+                                          units=units, precision=precision),
+                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
+                          self.original_extra_repr()])
+
+    def add_extra_repr(m):
+        m.accumulate_flops = accumulate_flops.__get__(m)
+        m.accumulate_params = accumulate_params.__get__(m)
+        flops_extra_repr = flops_repr.__get__(m)
+        if m.extra_repr != flops_extra_repr:
+            m.original_extra_repr = m.extra_repr
+            m.extra_repr = flops_extra_repr
+            assert m.extra_repr != m.original_extra_repr
+
+    def del_extra_repr(m):
+        if hasattr(m, 'original_extra_repr'):
+            m.extra_repr = m.original_extra_repr
+            del m.original_extra_repr
+        if hasattr(m, 'accumulate_flops'):
+            del m.accumulate_flops
+
+    model.apply(add_extra_repr)
+    print(repr(model), file=ost)
+    model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return params_num
+
+
+def add_flops_counting_methods(net_main_module):
+    # adding additional methods to the existing module object,
+    # this is done this way so that each function has access to self object
+    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
+    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
+    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
+    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(
+                                                    net_main_module)
+
+    net_main_module.reset_flops_count()
+
+    return net_main_module
+
+
+def compute_average_flops_cost(self):
+    """
+    A method that will be available after add_flops_counting_methods() is called
+    on a desired net object.
+
+    Returns current mean flops consumption per image.
+
+    """
+
+    for m in self.modules():
+        m.accumulate_flops = accumulate_flops.__get__(m)
+
+    flops_sum = self.accumulate_flops()
+
+    for m in self.modules():
+        if hasattr(m, 'accumulate_flops'):
+            del m.accumulate_flops
+
+    params_sum = get_model_parameters_number(self)
+    return flops_sum / self.__batch_counter__, params_sum
+
+
+def start_flops_count(self, **kwargs):
+    """
+    A method that will be available after add_flops_counting_methods() is called
+    on a desired net object.
+
+    Activates the computation of mean flops consumption per image.
+    Call it before you run the network.
+
+    """
+    add_batch_counter_hook_function(self)
+
+    seen_types = set()
+
+    def add_flops_counter_hook_function(module, ost, verbose, ignore_list):
+        if type(module) in ignore_list:
+            seen_types.add(type(module))
+            if is_supported_instance(module):
+                module.__params__ = 0
+        elif is_supported_instance(module):
+            if hasattr(module, '__flops_handle__'):
+                return
+            if type(module) in CUSTOM_MODULES_MAPPING:
+                handle = module.register_forward_hook(
+                                        CUSTOM_MODULES_MAPPING[type(module)])
+            elif getattr(module, 'compute_macs', False):
+                handle = module.register_forward_hook(
+                    module.compute_macs
+                )
+            else:
+                handle = module.register_forward_hook(MODULES_MAPPING[type(module)])
+            module.__flops_handle__ = handle
+            seen_types.add(type(module))
+        else:
+            if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and \
+               not type(module) in seen_types:
+                print('Warning: module ' + type(module).__name__ +
+                      ' is treated as a zero-op.', file=ost)
+            seen_types.add(type(module))
+
+    self.apply(partial(add_flops_counter_hook_function, **kwargs))
+
+
+def stop_flops_count(self):
+    """
+    A method that will be available after add_flops_counting_methods() is called
+    on a desired net object.
+
+    Stops computing the mean flops consumption per image.
+    Call whenever you want to pause the computation.
+
+    """
+    remove_batch_counter_hook_function(self)
+    self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+    """
+    A method that will be available after add_flops_counting_methods() is called
+    on a desired net object.
+
+    Resets statistics computed so far.
+
+    """
+    add_batch_counter_variables_or_reset(self)
+    self.apply(add_flops_counter_variable_or_reset)
+
+
+# ---- Internal functions
+def empty_flops_counter_hook(module, input, output):
+    module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module, input, output):
+    output_size = output[0]
+    batch_size = output_size.shape[0]
+    output_elements_count = batch_size
+    for val in output_size.shape[1:]:
+        output_elements_count *= val
+    module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module, input, output):
+    active_elements_count = output.numel()
+    module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module, input, output):
+    input = input[0]
+    # pytorch checks dimensions, so here we don't care much
+    output_last_dim = output.shape[-1]
+    bias_flops = output_last_dim if module.bias is not None else 0
+    module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops)
+
+
+def pool_flops_counter_hook(module, input, output):
+    input = input[0]
+    module.__flops__ += int(np.prod(input.shape))
+
+
+def bn_flops_counter_hook(module, input, output):
+    input = input[0]
+
+    batch_flops = np.prod(input.shape)
+    if module.affine:
+        batch_flops *= 2
+    module.__flops__ += int(batch_flops)
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+
+    batch_size = input.shape[0]
+    output_dims = list(output.shape[2:])
+
+    kernel_dims = list(conv_module.kernel_size)
+    in_channels = conv_module.in_channels
+    out_channels = conv_module.out_channels
+    groups = conv_module.groups
+
+    filters_per_channel = out_channels // groups
+    conv_per_position_flops = int(np.prod(kernel_dims)) * \
+        in_channels * filters_per_channel
+
+    active_elements_count = batch_size * int(np.prod(output_dims))
+
+    overall_conv_flops = conv_per_position_flops * active_elements_count
+
+    bias_flops = 0
+
+    if conv_module.bias is not None:
+
+        bias_flops = out_channels * active_elements_count
+
+    overall_flops = overall_conv_flops + bias_flops
+
+    conv_module.__flops__ += int(overall_flops)
+
+
+def batch_counter_hook(module, input, output):
+    batch_size = 1
+    if len(input) > 0:
+        # Can have multiple inputs, getting the first one
+        input = input[0]
+        batch_size = len(input)
+    else:
+        pass
+        print('Warning! No positional inputs found for a module,'
+              ' assuming batch size is 1.')
+    module.__batch_counter__ += batch_size
+
+
+def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
+    # matrix matrix mult ih state and internal state
+    flops += w_ih.shape[0]*w_ih.shape[1]
+    # matrix matrix mult hh state and internal state
+    flops += w_hh.shape[0]*w_hh.shape[1]
+    if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
+        # add both operations
+        flops += rnn_module.hidden_size
+    elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
+        # hadamard of r
+        flops += rnn_module.hidden_size
+        # adding operations from both states
+        flops += rnn_module.hidden_size*3
+        # last two hadamard product and add
+        flops += rnn_module.hidden_size*3
+    elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
+        # adding operations from both states
+        flops += rnn_module.hidden_size*4
+        # two hadamard product and add for C state
+        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
+        # final hadamard
+        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
+    return flops
+
+
+def rnn_flops_counter_hook(rnn_module, input, output):
+    """
+    Takes into account batch goes at first position, contrary
+    to pytorch common rule (but actually it doesn't matter).
+    IF sigmoid and tanh are made hard, only a comparison FLOPS should be accurate
+    """
+    flops = 0
+    # input is a tuple containing a sequence to process and (optionally) hidden state
+    inp = input[0]
+    batch_size = inp.shape[0]
+    seq_length = inp.shape[1]
+    num_layers = rnn_module.num_layers
+
+    for i in range(num_layers):
+        w_ih = rnn_module.__getattr__('weight_ih_l' + str(i))
+        w_hh = rnn_module.__getattr__('weight_hh_l' + str(i))
+        if i == 0:
+            input_size = rnn_module.input_size
+        else:
+            input_size = rnn_module.hidden_size
+        flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
+        if rnn_module.bias:
+            b_ih = rnn_module.__getattr__('bias_ih_l' + str(i))
+            b_hh = rnn_module.__getattr__('bias_hh_l' + str(i))
+            flops += b_ih.shape[0] + b_hh.shape[0]
+
+    flops *= batch_size
+    flops *= seq_length
+    if rnn_module.bidirectional:
+        flops *= 2
+    rnn_module.__flops__ += int(flops)
+
+
+def rnn_cell_flops_counter_hook(rnn_cell_module, input, output):
+    flops = 0
+    inp = input[0]
+    batch_size = inp.shape[0]
+    w_ih = rnn_cell_module.__getattr__('weight_ih')
+    w_hh = rnn_cell_module.__getattr__('weight_hh')
+    input_size = inp.shape[1]
+    flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
+    if rnn_cell_module.bias:
+        b_ih = rnn_cell_module.__getattr__('bias_ih')
+        b_hh = rnn_cell_module.__getattr__('bias_hh')
+        flops += b_ih.shape[0] + b_hh.shape[0]
+
+    flops *= batch_size
+    rnn_cell_module.__flops__ += int(flops)
+
+
+def add_batch_counter_variables_or_reset(module):
+
+    module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        return
+
+    handle = module.register_forward_hook(batch_counter_hook)
+    module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        module.__batch_counter_handle__.remove()
+        del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+            print('Warning: variables __flops__ or __params__ are already '
+                  'defined for the module' + type(module).__name__ +
+                  ' ptflops can affect your code!')
+        module.__flops__ = 0
+        module.__params__ = get_model_parameters_number(module)
+
+
+CUSTOM_MODULES_MAPPING = {}
+
+MODULES_MAPPING = {
+    # convolutions
+    nn.Conv1d: conv_flops_counter_hook,
+    nn.Conv2d: conv_flops_counter_hook,
+    nn.Conv3d: conv_flops_counter_hook,
+    Conv2d: conv_flops_counter_hook,
+    ModulatedDeformConv: conv_flops_counter_hook,
+    # activations
+    nn.ReLU: relu_flops_counter_hook,
+    nn.PReLU: relu_flops_counter_hook,
+    nn.ELU: relu_flops_counter_hook,
+    nn.LeakyReLU: relu_flops_counter_hook,
+    nn.ReLU6: relu_flops_counter_hook,
+    # poolings
+    nn.MaxPool1d: pool_flops_counter_hook,
+    nn.AvgPool1d: pool_flops_counter_hook,
+    nn.AvgPool2d: pool_flops_counter_hook,
+    nn.MaxPool2d: pool_flops_counter_hook,
+    nn.MaxPool3d: pool_flops_counter_hook,
+    nn.AvgPool3d: pool_flops_counter_hook,
+    nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+    nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+    nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+    nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+    nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+    nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+    # BNs
+    nn.BatchNorm1d: bn_flops_counter_hook,
+    nn.BatchNorm2d: bn_flops_counter_hook,
+    nn.BatchNorm3d: bn_flops_counter_hook,
+    nn.GroupNorm : bn_flops_counter_hook,
+    # FC
+    nn.Linear: linear_flops_counter_hook,
+    # Upscale
+    nn.Upsample: upsample_flops_counter_hook,
+    # Deconvolution
+    nn.ConvTranspose1d: conv_flops_counter_hook,
+    nn.ConvTranspose2d: conv_flops_counter_hook,
+    nn.ConvTranspose3d: conv_flops_counter_hook,
+    ConvTranspose2d: conv_flops_counter_hook,
+    # RNN
+    nn.RNN: rnn_flops_counter_hook,
+    nn.GRU: rnn_flops_counter_hook,
+    nn.LSTM: rnn_flops_counter_hook,
+    nn.RNNCell: rnn_cell_flops_counter_hook,
+    nn.LSTMCell: rnn_cell_flops_counter_hook,
+    nn.GRUCell: rnn_cell_flops_counter_hook
+}
+
+
+def is_supported_instance(module):
+    if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING \
+            or getattr(module, 'compute_macs', False):
+        return True
+    return False
+
+
+def remove_flops_counter_hook_function(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops_handle__'):
+            module.__flops_handle__.remove()
+            del module.__flops_handle__
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/blip.py b/models/blip.py
new file mode 100644
index 0000000000000000000000000000000000000000..38678f65ea2c276b351c2c97d429ebc2525ddcf7
--- /dev/null
+++ b/models/blip.py
@@ -0,0 +1,238 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import warnings
+warnings.filterwarnings("ignore")
+
+from models.vit import VisionTransformer, interpolate_pos_embed
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import os
+from urllib.parse import urlparse
+from timm.models.hub import download_cached_file
+
+class BLIP_Base(nn.Module):
+    def __init__(self,                 
+                 med_config = 'configs/med_config.json',  
+                 image_size = 224,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,                 
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """               
+        super().__init__()
+        
+        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+        self.tokenizer = init_tokenizer()   
+        med_config = BertConfig.from_json_file(med_config)
+        med_config.encoder_width = vision_width
+        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)  
+
+        
+    def forward(self, image, caption, mode):
+        
+        assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
+        text = self.tokenizer(caption, return_tensors="pt").to(image.device) 
+        
+        if mode=='image':    
+            # return image features
+            image_embeds = self.visual_encoder(image)             
+            return image_embeds
+        
+        elif mode=='text':
+            # return text features
+            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,                      
+                                            return_dict = True, mode = 'text')  
+            return text_output.last_hidden_state
+        
+        elif mode=='multimodal':
+            # return multimodel features
+            image_embeds = self.visual_encoder(image)    
+            image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)      
+            
+            text.input_ids[:,0] = self.tokenizer.enc_token_id
+            output = self.text_encoder(text.input_ids,
+                                       attention_mask = text.attention_mask,
+                                       encoder_hidden_states = image_embeds,
+                                       encoder_attention_mask = image_atts,      
+                                       return_dict = True,
+                                      )              
+            return output.last_hidden_state
+        
+        
+        
+class BLIP_Decoder(nn.Module):
+    def __init__(self,                 
+                 med_config = 'configs/med_config.json',  
+                 image_size = 384,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,
+                 prompt = 'a picture of ',
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """            
+        super().__init__()
+        
+        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+        self.tokenizer = init_tokenizer()   
+        med_config = BertConfig.from_json_file(med_config)
+        med_config.encoder_width = vision_width
+        self.text_decoder = BertLMHeadModel(config=med_config)    
+        
+        self.prompt = prompt
+        self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
+
+        
+    def forward(self, image, caption):
+        
+        image_embeds = self.visual_encoder(image) 
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+        
+        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 
+        
+        text.input_ids[:,0] = self.tokenizer.bos_token_id
+        
+        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)         
+        decoder_targets[:,:self.prompt_length] = -100
+     
+        decoder_output = self.text_decoder(text.input_ids, 
+                                           attention_mask = text.attention_mask, 
+                                           encoder_hidden_states = image_embeds,
+                                           encoder_attention_mask = image_atts,                  
+                                           labels = decoder_targets,
+                                           return_dict = True,   
+                                          )   
+        loss_lm = decoder_output.loss
+        
+        return loss_lm
+        
+    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
+        image_embeds = self.visual_encoder(image)
+
+        if not sample:
+            image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
+            
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
+        
+        prompt = [self.prompt] * image.size(0)
+        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 
+        input_ids[:,0] = self.tokenizer.bos_token_id
+        input_ids = input_ids[:, :-1] 
+
+        if sample:
+            #nucleus sampling
+            outputs = self.text_decoder.generate(input_ids=input_ids,
+                                                  max_length=max_length,
+                                                  min_length=min_length,
+                                                  do_sample=True,
+                                                  top_p=top_p,
+                                                  num_return_sequences=1,
+                                                  eos_token_id=self.tokenizer.sep_token_id,
+                                                  pad_token_id=self.tokenizer.pad_token_id, 
+                                                  repetition_penalty=1.1,                                            
+                                                  **model_kwargs)
+        else:
+            #beam search
+            outputs = self.text_decoder.generate(input_ids=input_ids,
+                                                  max_length=max_length,
+                                                  min_length=min_length,
+                                                  num_beams=num_beams,
+                                                  eos_token_id=self.tokenizer.sep_token_id,
+                                                  pad_token_id=self.tokenizer.pad_token_id,     
+                                                  repetition_penalty=repetition_penalty,
+                                                  **model_kwargs)            
+            
+        captions = []    
+        for output in outputs:
+            caption = self.tokenizer.decode(output, skip_special_tokens=True)    
+            captions.append(caption[len(self.prompt):])
+        return captions
+    
+
+def blip_decoder(pretrained='',**kwargs):
+    model = BLIP_Decoder(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+        assert(len(msg.missing_keys)==0)
+    return model    
+    
+def blip_feature_extractor(pretrained='',**kwargs):
+    model = BLIP_Base(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+        assert(len(msg.missing_keys)==0)
+    return model        
+
+def init_tokenizer():
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
+    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})       
+    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]  
+    return tokenizer
+
+
+def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
+        
+    assert vit in ['base', 'large'], "vit parameter must be base or large"
+    if vit=='base':
+        vision_width = 768
+        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 
+                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+                                           drop_path_rate=0 or drop_path_rate
+                                          )   
+    elif vit=='large':
+        vision_width = 1024
+        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 
+                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+                                           drop_path_rate=0.1 or drop_path_rate
+                                          )   
+    return visual_encoder, vision_width
+
+def is_url(url_or_filename):
+    parsed = urlparse(url_or_filename)
+    return parsed.scheme in ("http", "https")
+
+def load_checkpoint(model,url_or_filename):
+    if is_url(url_or_filename):
+        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+        checkpoint = torch.load(cached_file, map_location='cpu') 
+    elif os.path.isfile(url_or_filename):        
+        checkpoint = torch.load(url_or_filename, map_location='cpu') 
+    else:
+        raise RuntimeError('checkpoint url or path is invalid')
+        
+    state_dict = checkpoint['model']
+    
+    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 
+    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
+        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
+                                                                         model.visual_encoder_m)    
+    for key in model.state_dict().keys():
+        if key in state_dict.keys():
+            if state_dict[key].shape!=model.state_dict()[key].shape:
+                del state_dict[key]
+    
+    msg = model.load_state_dict(state_dict,strict=False)
+    print('load checkpoint from %s'%url_or_filename)  
+    return model,msg
+    
diff --git a/models/blip_itm.py b/models/blip_itm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf354c829564bf5a1f56089a2d745093d51e0fa2
--- /dev/null
+++ b/models/blip_itm.py
@@ -0,0 +1,76 @@
+from models.med import BertConfig, BertModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_ITM(nn.Module):
+    def __init__(self,                 
+                 med_config = 'configs/med_config.json',  
+                 image_size = 384,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,                      
+                 embed_dim = 256,     
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """               
+        super().__init__()
+        
+        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+        self.tokenizer = init_tokenizer()   
+        med_config = BertConfig.from_json_file(med_config)
+        med_config.encoder_width = vision_width
+        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)          
+
+        text_width = self.text_encoder.config.hidden_size
+        
+        self.vision_proj = nn.Linear(vision_width, embed_dim)
+        self.text_proj = nn.Linear(text_width, embed_dim)
+
+        self.itm_head = nn.Linear(text_width, 2) 
+        
+        
+    def forward(self, image, caption, match_head='itm'):
+
+        image_embeds = self.visual_encoder(image) 
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)        
+      
+        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 
+                              return_tensors="pt").to(image.device) 
+
+                 
+        if match_head=='itm':
+            output = self.text_encoder(text.input_ids,
+                                       attention_mask = text.attention_mask,
+                                       encoder_hidden_states = image_embeds,
+                                       encoder_attention_mask = image_atts,      
+                                       return_dict = True,
+                                      )
+            itm_output = self.itm_head(output.last_hidden_state[:,0,:])     
+            return itm_output
+            
+        elif match_head=='itc':
+            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,                      
+                                            return_dict = True, mode = 'text')                     
+            image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)   
+            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)    
+            
+            sim = image_feat @ text_feat.t()
+            return sim
+        
+        
+def blip_itm(pretrained='',**kwargs):
+    model = BLIP_ITM(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+        assert(len(msg.missing_keys)==0)
+    return model         
+            
\ No newline at end of file
diff --git a/models/blip_nlvr.py b/models/blip_nlvr.py
new file mode 100644
index 0000000000000000000000000000000000000000..84837167bfa6874d3c3e41fb9b37271113910b7f
--- /dev/null
+++ b/models/blip_nlvr.py
@@ -0,0 +1,103 @@
+from models.med import BertConfig
+from models.nlvr_encoder import BertModel
+from models.vit import interpolate_pos_embed
+from models.blip import create_vit, init_tokenizer, is_url
+
+from timm.models.hub import download_cached_file
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import BertTokenizer
+import numpy as np
+
+class BLIP_NLVR(nn.Module):
+    def __init__(self,                 
+                 med_config = 'configs/med_config.json',  
+                 image_size = 480,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,                   
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """               
+        super().__init__()
+        
+        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
+        self.tokenizer = init_tokenizer()   
+        med_config = BertConfig.from_json_file(med_config)
+        med_config.encoder_width = vision_width
+        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 
+                    
+        self.cls_head = nn.Sequential(
+                  nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
+                  nn.ReLU(),
+                  nn.Linear(self.text_encoder.config.hidden_size, 2)
+                )  
+
+    def forward(self, image, text, targets, train=True):
+        
+        image_embeds = self.visual_encoder(image) 
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)        
+        image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))     
+
+        text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 
+        text.input_ids[:,0] = self.tokenizer.enc_token_id        
+
+        output = self.text_encoder(text.input_ids, 
+                                   attention_mask = text.attention_mask, 
+                                   encoder_hidden_states = [image0_embeds,image1_embeds],
+                                   encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
+                                                             image_atts[image0_embeds.size(0):]],        
+                                   return_dict = True,
+                                  )  
+        hidden_state = output.last_hidden_state[:,0,:]        
+        prediction = self.cls_head(hidden_state)
+
+        if train:            
+            loss = F.cross_entropy(prediction, targets)   
+            return loss
+        else:
+            return prediction
+    
+def blip_nlvr(pretrained='',**kwargs):
+    model = BLIP_NLVR(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+        print("missing keys:")
+        print(msg.missing_keys)
+    return model  
+
+        
+def load_checkpoint(model,url_or_filename):
+    if is_url(url_or_filename):
+        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+        checkpoint = torch.load(cached_file, map_location='cpu') 
+    elif os.path.isfile(url_or_filename):        
+        checkpoint = torch.load(url_or_filename, map_location='cpu') 
+    else:
+        raise RuntimeError('checkpoint url or path is invalid')
+    state_dict = checkpoint['model']
+    
+    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 
+    
+    for key in list(state_dict.keys()):
+        if 'crossattention.self.' in key:
+            new_key0 = key.replace('self','self0')
+            new_key1 = key.replace('self','self1')
+            state_dict[new_key0] = state_dict[key]
+            state_dict[new_key1] = state_dict[key]
+        elif 'crossattention.output.dense.' in key:
+            new_key0 = key.replace('dense','dense0')
+            new_key1 = key.replace('dense','dense1')
+            state_dict[new_key0] = state_dict[key]
+            state_dict[new_key1] = state_dict[key]  
+                
+    msg = model.load_state_dict(state_dict,strict=False)
+    print('load checkpoint from %s'%url_or_filename)  
+    return model,msg
+            
\ No newline at end of file
diff --git a/models/blip_pretrain.py b/models/blip_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42ce5f998b0a51e6f731ee6b5c8bae6d02a8664
--- /dev/null
+++ b/models/blip_pretrain.py
@@ -0,0 +1,339 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from transformers import BertTokenizer
+import transformers
+transformers.logging.set_verbosity_error()
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_Pretrain(nn.Module):
+    def __init__(self,                 
+                 med_config = 'configs/bert_config.json',  
+                 image_size = 224,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,                    
+                 embed_dim = 256,     
+                 queue_size = 57600,
+                 momentum = 0.995,
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """               
+        super().__init__()
+        
+        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
+        
+        if vit=='base':
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+                map_location="cpu", check_hash=True)
+            state_dict = checkpoint["model"]     
+            msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
+        elif vit=='large':
+            from timm.models.helpers import load_custom_pretrained
+            from timm.models.vision_transformer import default_cfgs
+            load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])        
+               
+        self.tokenizer = init_tokenizer()   
+        encoder_config = BertConfig.from_json_file(med_config)
+        encoder_config.encoder_width = vision_width
+        self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
+        self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 
+
+        text_width = self.text_encoder.config.hidden_size
+        
+        self.vision_proj = nn.Linear(vision_width, embed_dim)
+        self.text_proj = nn.Linear(text_width, embed_dim)
+
+        self.itm_head = nn.Linear(text_width, 2) 
+        
+        # create momentum encoders  
+        self.visual_encoder_m, vision_width = create_vit(vit,image_size)              
+        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
+        self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)      
+        self.text_proj_m = nn.Linear(text_width, embed_dim)
+        
+        self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
+                            [self.vision_proj,self.vision_proj_m],
+                            [self.text_encoder,self.text_encoder_m],
+                            [self.text_proj,self.text_proj_m],
+                           ]       
+        self.copy_params()
+
+        # create the queue
+        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
+        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
+        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))  
+
+        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
+        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
+        
+        self.queue_size = queue_size
+        self.momentum = momentum
+        self.temp = nn.Parameter(0.07*torch.ones([]))   
+        
+        # create the decoder
+        decoder_config = BertConfig.from_json_file(med_config)
+        decoder_config.encoder_width = vision_width        
+        self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)    
+        self.text_decoder.resize_token_embeddings(len(self.tokenizer)) 
+        tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
+        
+        
+    def forward(self, image, caption, alpha):
+        with torch.no_grad():
+            self.temp.clamp_(0.001,0.5)
+        
+        image_embeds = self.visual_encoder(image) 
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)        
+        image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)          
+        
+        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30, 
+                              return_tensors="pt").to(image.device)  
+        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,                      
+                                        return_dict = True, mode = 'text')            
+        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)                 
+             
+        # get momentum features
+        with torch.no_grad():
+            self._momentum_update()
+            image_embeds_m = self.visual_encoder_m(image) 
+            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
+            image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                   
+            
+            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,                      
+                                                return_dict = True, mode = 'text')    
+            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
+            text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
+
+            sim_i2t_m = image_feat_m @ text_feat_all / self.temp  
+            sim_t2i_m = text_feat_m @ image_feat_all / self.temp 
+
+            sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
+            sim_targets.fill_diagonal_(1)          
+
+            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
+            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets        
+
+        sim_i2t = image_feat @ text_feat_all / self.temp
+        sim_t2i = text_feat @ image_feat_all / self.temp
+                             
+        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
+        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 
+
+        loss_ita = (loss_i2t+loss_t2i)/2
+
+        self._dequeue_and_enqueue(image_feat_m, text_feat_m)        
+
+        ###============== Image-text Matching ===================###
+        encoder_input_ids = text.input_ids.clone()
+        encoder_input_ids[:,0] = self.tokenizer.enc_token_id
+        
+        # forward the positve image-text pair
+        bs = image.size(0)
+        output_pos = self.text_encoder(encoder_input_ids,
+                                       attention_mask = text.attention_mask,
+                                       encoder_hidden_states = image_embeds,
+                                       encoder_attention_mask = image_atts,      
+                                       return_dict = True,
+                                      )            
+        with torch.no_grad():       
+            weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4 
+            weights_t2i.fill_diagonal_(0)            
+            weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4  
+            weights_i2t.fill_diagonal_(0)   
+            
+        # select a negative image for each text
+        image_embeds_neg = []    
+        for b in range(bs):
+            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+            image_embeds_neg.append(image_embeds[neg_idx])
+        image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   
+
+        # select a negative text for each image
+        text_ids_neg = []
+        text_atts_neg = []
+        for b in range(bs):
+            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+            text_ids_neg.append(encoder_input_ids[neg_idx])
+            text_atts_neg.append(text.attention_mask[neg_idx])
+
+        text_ids_neg = torch.stack(text_ids_neg,dim=0)   
+        text_atts_neg = torch.stack(text_atts_neg,dim=0)      
+
+        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)     
+        text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)     
+
+        image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
+        image_atts_all = torch.cat([image_atts,image_atts],dim=0)
+
+        output_neg = self.text_encoder(text_ids_all,
+                                       attention_mask = text_atts_all,
+                                       encoder_hidden_states = image_embeds_all,
+                                       encoder_attention_mask = image_atts_all,      
+                                       return_dict = True,
+                                      )                            
+
+        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
+        vl_output = self.itm_head(vl_embeddings)            
+
+        itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
+                               dim=0).to(image.device)
+        loss_itm = F.cross_entropy(vl_output, itm_labels)  
+        
+        ##================= LM ========================##     
+        decoder_input_ids = text.input_ids.clone()      
+        decoder_input_ids[:,0] = self.tokenizer.bos_token_id
+        decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100) 
+
+        decoder_output = self.text_decoder(decoder_input_ids, 
+                                           attention_mask = text.attention_mask, 
+                                           encoder_hidden_states = image_embeds,
+                                           encoder_attention_mask = image_atts,                  
+                                           labels = decoder_targets,
+                                           return_dict = True,   
+                                          )   
+          
+        loss_lm = decoder_output.loss                
+        return loss_ita, loss_itm, loss_lm
+ 
+
+
+    @torch.no_grad()    
+    def copy_params(self):
+        for model_pair in self.model_pairs:           
+            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+                param_m.data.copy_(param.data)  # initialize
+                param_m.requires_grad = False  # not update by gradient    
+
+            
+    @torch.no_grad()        
+    def _momentum_update(self):
+        for model_pair in self.model_pairs:           
+            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
+
+                        
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, image_feat, text_feat):
+        # gather keys before updating queue
+        image_feats = concat_all_gather(image_feat)
+        text_feats = concat_all_gather(text_feat)
+
+        batch_size = image_feats.shape[0]
+
+        ptr = int(self.queue_ptr)
+        assert self.queue_size % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+        ptr = (ptr + batch_size) % self.queue_size  # move pointer
+
+        self.queue_ptr[0] = ptr 
+
+
+def blip_pretrain(**kwargs):
+    model = BLIP_Pretrain(**kwargs)
+    return model 
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor)
+        for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output     
+
+
+from typing import List
+def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
+    uninitialized_encoder_weights: List[str] = []
+    if decoder.__class__ != encoder.__class__:
+        logger.info(
+            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
+        )
+
+    def tie_encoder_to_decoder_recursively(
+        decoder_pointer: nn.Module,
+        encoder_pointer: nn.Module,
+        module_name: str,
+        uninitialized_encoder_weights: List[str],
+        skip_key: str,
+        depth=0,
+    ):
+        assert isinstance(decoder_pointer, nn.Module) and isinstance(
+            encoder_pointer, nn.Module
+        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
+        if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
+            assert hasattr(encoder_pointer, "weight")
+            encoder_pointer.weight = decoder_pointer.weight
+            if hasattr(decoder_pointer, "bias"):
+                assert hasattr(encoder_pointer, "bias")
+                encoder_pointer.bias = decoder_pointer.bias                
+            print(module_name+' is tied')    
+            return
+
+        encoder_modules = encoder_pointer._modules
+        decoder_modules = decoder_pointer._modules
+        if len(decoder_modules) > 0:
+            assert (
+                len(encoder_modules) > 0
+            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+
+            all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
+            encoder_layer_pos = 0
+            for name, module in decoder_modules.items():
+                if name.isdigit():
+                    encoder_name = str(int(name) + encoder_layer_pos)
+                    decoder_name = name
+                    if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
+                        encoder_modules
+                    ) != len(decoder_modules):
+                        # this can happen if the name corresponds to the position in a list module list of layers
+                        # in this case the decoder has added a cross-attention that the encoder does not have
+                        # thus skip this step and subtract one layer pos from encoder
+                        encoder_layer_pos -= 1
+                        continue
+                elif name not in encoder_modules:
+                    continue
+                elif depth > 500:
+                    raise ValueError(
+                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
+                    )
+                else:
+                    decoder_name = encoder_name = name
+                tie_encoder_to_decoder_recursively(
+                    decoder_modules[decoder_name],
+                    encoder_modules[encoder_name],
+                    module_name + "/" + name,
+                    uninitialized_encoder_weights,
+                    skip_key,
+                    depth=depth + 1,
+                )
+                all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+            uninitialized_encoder_weights += list(all_encoder_weights)
+
+    # tie weights recursively
+    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)  
diff --git a/models/blip_retrieval.py b/models/blip_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1debe7e2e664f8dd603f8d4c537e3599c68638d7
--- /dev/null
+++ b/models/blip_retrieval.py
@@ -0,0 +1,319 @@
+from models.med import BertConfig, BertModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_Retrieval(nn.Module):
+    def __init__(self,                 
+                 med_config = 'configs/med_config.json',  
+                 image_size = 384,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,                      
+                 embed_dim = 256,     
+                 queue_size = 57600,
+                 momentum = 0.995,
+                 negative_all_rank = False,
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """               
+        super().__init__()
+        
+        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+        self.tokenizer = init_tokenizer()   
+        med_config = BertConfig.from_json_file(med_config)
+        med_config.encoder_width = vision_width
+        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)          
+
+        text_width = self.text_encoder.config.hidden_size
+        
+        self.vision_proj = nn.Linear(vision_width, embed_dim)
+        self.text_proj = nn.Linear(text_width, embed_dim)
+
+        self.itm_head = nn.Linear(text_width, 2) 
+        
+        # create momentum encoders  
+        self.visual_encoder_m, vision_width = create_vit(vit,image_size)              
+        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
+        self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)    
+        self.text_proj_m = nn.Linear(text_width, embed_dim)
+        
+        self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
+                            [self.vision_proj,self.vision_proj_m],
+                            [self.text_encoder,self.text_encoder_m],
+                            [self.text_proj,self.text_proj_m],
+                           ]       
+        self.copy_params()
+
+        # create the queue
+        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
+        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
+        self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
+        self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))  
+
+        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
+        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
+        
+        self.queue_size = queue_size
+        self.momentum = momentum
+        self.temp = nn.Parameter(0.07*torch.ones([]))   
+        
+        self.negative_all_rank = negative_all_rank
+        
+        
+    def forward(self, image, caption, alpha, idx):
+        with torch.no_grad():
+            self.temp.clamp_(0.001,0.5)
+        
+        image_embeds = self.visual_encoder(image) 
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)        
+        image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)    
+        
+        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 
+                              return_tensors="pt").to(image.device) 
+        
+        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,                      
+                                        return_dict = True, mode = 'text')            
+        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)        
+        
+        ###============== Image-text Contrastive Learning ===================###
+        idx = idx.view(-1,1)
+        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)  
+        pos_idx = torch.eq(idx, idx_all).float()       
+        sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)   
+        
+        # get momentum features
+        with torch.no_grad():
+            self._momentum_update()
+            image_embeds_m = self.visual_encoder_m(image) 
+            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
+            image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                   
+            
+            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,                      
+                                                return_dict = True, mode = 'text')    
+            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
+            text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
+
+            sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp  
+            sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp   
+
+            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
+            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets        
+
+        sim_i2t = image_feat @ text_feat_m_all / self.temp 
+        sim_t2i = text_feat @ image_feat_m_all / self.temp 
+                             
+        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
+        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 
+
+        loss_ita = (loss_i2t+loss_t2i)/2
+        
+        idxs = concat_all_gather(idx)
+        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)        
+
+        ###============== Image-text Matching ===================###
+        encoder_input_ids = text.input_ids.clone()
+        encoder_input_ids[:,0] = self.tokenizer.enc_token_id
+
+        # forward the positve image-text pair
+        bs = image.size(0)
+        output_pos = self.text_encoder(encoder_input_ids,
+                                       attention_mask = text.attention_mask,
+                                       encoder_hidden_states = image_embeds,
+                                       encoder_attention_mask = image_atts,      
+                                       return_dict = True,
+                                      )  
+        
+        
+        if self.negative_all_rank:    
+            # compute sample similarity
+            with torch.no_grad():                
+                mask = torch.eq(idx, idxs.t())
+
+                image_feat_world = concat_all_gather(image_feat)
+                text_feat_world = concat_all_gather(text_feat)
+
+                sim_i2t = image_feat @ text_feat_world.t() / self.temp 
+                sim_t2i = text_feat @ image_feat_world.t() / self.temp 
+
+                weights_i2t = F.softmax(sim_i2t,dim=1)
+                weights_i2t.masked_fill_(mask, 0)            
+
+                weights_t2i = F.softmax(sim_t2i,dim=1)
+                weights_t2i.masked_fill_(mask, 0)     
+
+            image_embeds_world = all_gather_with_grad(image_embeds) 
+
+            # select a negative image (from all ranks) for each text
+            image_embeds_neg = []    
+            for b in range(bs):
+                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+                image_embeds_neg.append(image_embeds_world[neg_idx])
+            image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   
+
+            # select a negative text (from all ranks) for each image
+            input_ids_world = concat_all_gather(encoder_input_ids)
+            att_mask_world = concat_all_gather(text.attention_mask)        
+
+            text_ids_neg = []
+            text_atts_neg = []
+            for b in range(bs):
+                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+                text_ids_neg.append(input_ids_world[neg_idx])
+                text_atts_neg.append(att_mask_world[neg_idx])
+                
+        else:
+            with torch.no_grad():                
+                mask = torch.eq(idx, idx.t())
+                
+                sim_i2t = image_feat @ text_feat.t() / self.temp 
+                sim_t2i = text_feat @ image_feat.t() / self.temp 
+
+                weights_i2t = F.softmax(sim_i2t,dim=1)
+                weights_i2t.masked_fill_(mask, 0)            
+
+                weights_t2i = F.softmax(sim_t2i,dim=1)
+                weights_t2i.masked_fill_(mask, 0)     
+
+            # select a negative image (from same rank) for each text
+            image_embeds_neg = []    
+            for b in range(bs):
+                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+                image_embeds_neg.append(image_embeds[neg_idx])
+            image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   
+
+            # select a negative text (from same rank) for each image    
+            text_ids_neg = []
+            text_atts_neg = []
+            for b in range(bs):
+                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+                text_ids_neg.append(encoder_input_ids[neg_idx])
+                text_atts_neg.append(text.attention_mask[neg_idx])            
+            
+        text_ids_neg = torch.stack(text_ids_neg,dim=0)   
+        text_atts_neg = torch.stack(text_atts_neg,dim=0)      
+
+        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)     
+        text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)     
+
+        image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
+        image_atts_all = torch.cat([image_atts,image_atts],dim=0)
+
+        output_neg = self.text_encoder(text_ids_all,
+                                       attention_mask = text_atts_all,
+                                       encoder_hidden_states = image_embeds_all,
+                                       encoder_attention_mask = image_atts_all,      
+                                       return_dict = True,
+                                      )                         
+          
+
+        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
+        vl_output = self.itm_head(vl_embeddings)            
+
+        itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
+                               dim=0).to(image.device)
+        loss_itm = F.cross_entropy(vl_output, itm_labels)     
+
+        return loss_ita, loss_itm 
+ 
+
+    @torch.no_grad()    
+    def copy_params(self):
+        for model_pair in self.model_pairs:           
+            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+                param_m.data.copy_(param.data)  # initialize
+                param_m.requires_grad = False  # not update by gradient    
+
+            
+    @torch.no_grad()        
+    def _momentum_update(self):
+        for model_pair in self.model_pairs:           
+            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
+                
+                
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
+        # gather keys before updating queue
+        image_feats = concat_all_gather(image_feat)
+        text_feats = concat_all_gather(text_feat)
+        
+
+        batch_size = image_feats.shape[0]
+
+        ptr = int(self.ptr_queue)
+        assert self.queue_size % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+        self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
+        ptr = (ptr + batch_size) % self.queue_size # move pointer
+
+        self.ptr_queue[0] = ptr  
+
+
+def blip_retrieval(pretrained='',**kwargs):
+    model = BLIP_Retrieval(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+        print("missing keys:")
+        print(msg.missing_keys)
+    return model 
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor)
+        for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output      
+
+
+class GatherLayer(torch.autograd.Function):
+    """
+    Gather tensors from all workers with support for backward propagation:
+    This implementation does not cut the gradients as torch.distributed.all_gather does.
+    """
+
+    @staticmethod
+    def forward(ctx, x):
+        output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
+        torch.distributed.all_gather(output, x)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):
+        all_gradients = torch.stack(grads)
+        torch.distributed.all_reduce(all_gradients)
+        return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+    """
+    Performs all_gather operation on the provided tensors.
+    Graph remains connected for backward grad computation.
+    """
+    # Queue the gathered tensors
+    world_size = torch.distributed.get_world_size()
+    # There is no need for reduction in the single-proc case
+    if world_size == 1:
+        return tensors
+
+    tensor_all = GatherLayer.apply(tensors)
+
+    return torch.cat(tensor_all, dim=0)
diff --git a/models/blip_vqa.py b/models/blip_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ef8641cab1badd32e00abea352d764f6165faae
--- /dev/null
+++ b/models/blip_vqa.py
@@ -0,0 +1,223 @@
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import BertTokenizer
+import numpy as np
+
+class BLIP_VQA(nn.Module):
+    def __init__(self,                 
+                 med_config = 'configs/med_config.json',  
+                 image_size = 480,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,                   
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """               
+        super().__init__()
+        
+        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
+        self.tokenizer = init_tokenizer()  
+        
+        encoder_config = BertConfig.from_json_file(med_config)
+        encoder_config.encoder_width = vision_width
+        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 
+        
+        decoder_config = BertConfig.from_json_file(med_config)        
+        self.text_decoder = BertLMHeadModel(config=decoder_config)          
+
+        self.itm_head = nn.Linear(768, 2)
+
+    def forward(self, image, question, answer=None, n=None, weights=None, mode='inference', inference='rank', k_test=128):
+        
+        image_embeds = self.visual_encoder(image) 
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+        
+        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 
+                                  return_tensors="pt").to(image.device) 
+        question.input_ids[:,0] = self.tokenizer.enc_token_id
+        
+        if mode == 'train':               
+            '''
+            n: number of answers for each question
+            weights: weight for each answer
+            '''                     
+            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) 
+            answer.input_ids[:,0] = self.tokenizer.bos_token_id
+            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)      
+
+            question_output = self.text_encoder(question.input_ids, 
+                                                attention_mask = question.attention_mask, 
+                                                encoder_hidden_states = image_embeds,
+                                                encoder_attention_mask = image_atts,                             
+                                                return_dict = True)    
+
+            question_states = []                
+            question_atts = []  
+            for b, n in enumerate(n):
+                question_states += [question_output.last_hidden_state[b]]*n
+                question_atts += [question.attention_mask[b]]*n                
+            question_states = torch.stack(question_states,0)    
+            question_atts = torch.stack(question_atts,0)     
+
+            answer_output = self.text_decoder(answer.input_ids, 
+                                              attention_mask = answer.attention_mask, 
+                                              encoder_hidden_states = question_states,
+                                              encoder_attention_mask = question_atts,                  
+                                              labels = answer_targets,
+                                              return_dict = True,   
+                                              reduction = 'none',
+                                             )      
+            
+            loss = weights * answer_output.loss
+            loss = loss.sum()/image.size(0)
+
+            return loss
+            
+        elif mode == 'gradcam':
+            question_output = self.text_encoder(question.input_ids, 
+                                                attention_mask = question.attention_mask, 
+                                                encoder_hidden_states = image_embeds,
+                                                encoder_attention_mask = image_atts,                                    
+                                                return_dict = True)
+            
+            vl_embeddings = question_output.last_hidden_state[:,0,:]
+            vl_output = self.itm_head(vl_embeddings)   
+            
+            if inference=='generate':
+                num_beams = 3
+                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
+                question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
+                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
+                
+                bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
+                
+                outputs = self.text_decoder.generate(input_ids=bos_ids,
+                                                     max_length=10,
+                                                     min_length=1,
+                                                     num_beams=num_beams,
+                                                     eos_token_id=self.tokenizer.sep_token_id,
+                                                     pad_token_id=self.tokenizer.pad_token_id, 
+                                                     **model_kwargs)
+                
+                answers = []    
+                for output in outputs:
+                    answer = self.tokenizer.decode(output, skip_special_tokens=True)    
+                    answers.append(answer)
+                return answers, vl_output, question
+            
+            elif inference=='rank':
+                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 
+                                           answer.input_ids, answer.attention_mask, k_test) 
+                return max_ids, vl_output, question
+
+        else: 
+            question_output = self.text_encoder(question.input_ids, 
+                                                attention_mask = question.attention_mask, 
+                                                encoder_hidden_states = image_embeds,
+                                                encoder_attention_mask = image_atts,                                    
+                                                return_dict = True) 
+            
+            if inference=='generate':
+                num_beams = 3
+                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
+                question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
+                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
+                
+                bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
+                
+                outputs = self.text_decoder.generate(input_ids=bos_ids,
+                                                     max_length=10,
+                                                     min_length=1,
+                                                     num_beams=num_beams,
+                                                     eos_token_id=self.tokenizer.sep_token_id,
+                                                     pad_token_id=self.tokenizer.pad_token_id, 
+                                                     **model_kwargs)
+                
+                answers = []    
+                for output in outputs:
+                    answer = self.tokenizer.decode(output, skip_special_tokens=True)    
+                    answers.append(answer)
+                return answers
+            
+            elif inference=='rank':
+                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 
+                                           answer.input_ids, answer.attention_mask, k_test) 
+                return max_ids
+ 
+                
+                
+    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
+        
+        num_ques = question_states.size(0)
+        start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
+        
+        start_output = self.text_decoder(start_ids, 
+                                         encoder_hidden_states = question_states,
+                                         encoder_attention_mask = question_atts,                                      
+                                         return_dict = True,
+                                         reduction = 'none')              
+        logits = start_output.logits[:,0,:] # first token's logit
+        
+        # topk_probs: top-k probability 
+        # topk_ids: [num_question, k]        
+        answer_first_token = answer_ids[:,1]
+        prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) 
+        topk_probs, topk_ids = prob_first_token.topk(k,dim=1) 
+        
+        # answer input: [num_question*k, answer_len]                 
+        input_ids = []
+        input_atts = []
+        for b, topk_id in enumerate(topk_ids):
+            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
+            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
+        input_ids = torch.cat(input_ids,dim=0)  
+        input_atts = torch.cat(input_atts,dim=0)  
+
+        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
+
+        # repeat encoder's output for top-k answers
+        question_states = tile(question_states, 0, k)
+        question_atts = tile(question_atts, 0, k)
+        
+        output = self.text_decoder(input_ids, 
+                                   attention_mask = input_atts, 
+                                   encoder_hidden_states = question_states,
+                                   encoder_attention_mask = question_atts,     
+                                   labels = targets_ids,
+                                   return_dict = True, 
+                                   reduction = 'none')   
+        
+        log_probs_sum = -output.loss
+        log_probs_sum = log_probs_sum.view(num_ques,k)
+
+        max_topk_ids = log_probs_sum.argmax(dim=1) 
+        max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
+
+        return max_ids
+    
+    
+def blip_vqa(pretrained='',**kwargs):
+    model = BLIP_VQA(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+#         assert(len(msg.missing_keys)==0)
+    return model  
+
+
+def tile(x, dim, n_tile):
+    init_dim = x.size(dim)
+    repeat_idx = [1] * x.dim()
+    repeat_idx[dim] = n_tile
+    x = x.repeat(*(repeat_idx))
+    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
+    return torch.index_select(x, dim, order_index.to(x.device))    
+        
+        
\ No newline at end of file
diff --git a/models/med.py b/models/med.py
new file mode 100644
index 0000000000000000000000000000000000000000..99b0abab574a850320cc784aef4cc016f2b174c1
--- /dev/null
+++ b/models/med.py
@@ -0,0 +1,955 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+'''
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+    ModelOutput,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word and position embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        
+        self.config = config
+
+    def forward(
+        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        embeddings = inputs_embeds
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config, is_cross_attention):
+        super().__init__()
+        self.config = config
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+            )
+        
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        if is_cross_attention:
+            self.key = nn.Linear(config.encoder_width, self.all_head_size)
+            self.value = nn.Linear(config.encoder_width, self.all_head_size)
+        else:
+            self.key = nn.Linear(config.hidden_size, self.all_head_size)
+            self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+        self.save_attention = False   
+            
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+        
+    def get_attn_gradients(self):
+        return self.attn_gradients
+    
+    def save_attention_map(self, attention_map):
+        self.attention_map = attention_map
+        
+    def get_attention_map(self):
+        return self.attention_map
+    
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+        
+        if is_cross_attention and self.save_attention:
+            self.save_attention_map(attention_probs)
+            attention_probs.register_hook(self.save_attn_gradients)     
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs_dropped = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs_dropped = attention_probs_dropped * head_mask
+
+        context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config, is_cross_attention=False):
+        super().__init__()
+        self.self = BertSelfAttention(config, is_cross_attention)
+        self.output = BertSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config, layer_num):
+        super().__init__()
+        self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)      
+        self.layer_num = layer_num          
+        if self.config.add_cross_attention:
+            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        mode=None,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+
+        outputs = self_attention_outputs[1:-1]
+        present_key_value = self_attention_outputs[-1]
+
+        if mode=='multimodal':
+            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                output_attentions=output_attentions,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights                               
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        mode='multimodal',
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
+               
+        for i in range(self.config.num_hidden_layers):
+            layer_module = self.layer[i]
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    mode=mode,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    mode=mode,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = BertConfig
+    base_model_prefix = "bert"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def _init_weights(self, module):
+        """ Initialize the weights """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+    """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+        
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+ 
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device: (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+        """
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if is_decoder:
+                batch_size, seq_length = input_shape
+
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+   
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [
+                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )                     
+
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+    
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=False,
+        mode='multimodal',
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
+            device = input_ids.device
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length = input_shape
+            device = inputs_embeds.device
+        elif encoder_embeds is not None:    
+            input_shape = encoder_embeds.size()[:-1]
+            batch_size, seq_length = input_shape 
+            device = encoder_embeds.device
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+            
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, 
+                                                                                 device, is_decoder)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_hidden_states is not None:
+            if type(encoder_hidden_states) == list:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+            else:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            
+            if type(encoder_attention_mask) == list:
+                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+            elif encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            else:    
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+        
+        if encoder_embeds is None:
+            embedding_output = self.embeddings(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                inputs_embeds=inputs_embeds,
+                past_key_values_length=past_key_values_length,
+            )
+        else:
+            embedding_output = encoder_embeds
+            
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            mode=mode,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        return_logits=False,            
+        is_decoder=True,
+        reduction='mean',
+        mode='multimodal', 
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        Returns:
+        Example::
+            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+            >>> import torch
+            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+            >>> config = BertConfig.from_pretrained("bert-base-cased")
+            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> prediction_logits = outputs.logits
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+            mode=mode,
+        )
+        
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+        
+        if return_logits:
+            return prediction_scores[:, :-1, :].contiguous()  
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) 
+            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            if reduction=='none':
+                lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)               
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids, 
+            "attention_mask": attention_mask, 
+            "past_key_values": past,
+            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+            "is_decoder": True,
+        }
+
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
diff --git a/models/nlvr_encoder.py b/models/nlvr_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1946bb4a300f75afa4848f6622839445903c34a9
--- /dev/null
+++ b/models/nlvr_encoder.py
@@ -0,0 +1,843 @@
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+    ModelOutput,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word and position embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        
+        self.config = config
+
+    def forward(
+        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        embeddings = inputs_embeds
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config, is_cross_attention):
+        super().__init__()
+        self.config = config
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+            )
+        
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        if is_cross_attention:
+            self.key = nn.Linear(config.encoder_width, self.all_head_size)
+            self.value = nn.Linear(config.encoder_width, self.all_head_size)
+        else:
+            self.key = nn.Linear(config.hidden_size, self.all_head_size)
+            self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+        self.save_attention = False   
+            
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+        
+    def get_attn_gradients(self):
+        return self.attn_gradients
+    
+    def save_attention_map(self, attention_map):
+        self.attention_map = attention_map
+        
+    def get_attention_map(self):
+        return self.attention_map
+    
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+        
+        if is_cross_attention and self.save_attention:
+            self.save_attention_map(attention_probs)
+            attention_probs.register_hook(self.save_attn_gradients)         
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs_dropped = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs_dropped = attention_probs_dropped * head_mask
+
+        context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config, twin=False, merge=False):     
+        super().__init__()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)        
+        if twin:
+            self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
+            self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)         
+        else:
+            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if merge:
+            self.act =  ACT2FN[config.hidden_act]
+            self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
+            self.merge = True
+        else:
+            self.merge = False
+
+    def forward(self, hidden_states, input_tensor):
+        if type(hidden_states) == list:
+            hidden_states0 = self.dense0(hidden_states[0])
+            hidden_states1 = self.dense1(hidden_states[1])        
+            if self.merge:  
+                #hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
+                hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
+            else:
+                hidden_states = (hidden_states0+hidden_states1)/2
+        else:    
+            hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config, is_cross_attention=False, layer_num=-1):
+        super().__init__()
+        if is_cross_attention:
+            self.self0 = BertSelfAttention(config, is_cross_attention)
+            self.self1 = BertSelfAttention(config, is_cross_attention)
+        else:    
+            self.self = BertSelfAttention(config, is_cross_attention)
+        self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):        
+        if type(encoder_hidden_states)==list:   
+            self_outputs0 = self.self0(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states[0],
+                encoder_attention_mask[0],
+                past_key_value,
+                output_attentions,
+            )
+            self_outputs1 = self.self1(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states[1],
+                encoder_attention_mask[1],
+                past_key_value,
+                output_attentions,
+            )                        
+            attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
+    
+            outputs = (attention_output,) + self_outputs0[1:]  # add attentions if we output them
+        else:        
+            self_outputs = self.self(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
+            attention_output = self.output(self_outputs[0], hidden_states)
+            outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config, layer_num):
+        super().__init__()
+        self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)      
+        self.layer_num = layer_num          
+        if self.config.add_cross_attention:
+            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        mode=None,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+
+        outputs = self_attention_outputs[1:-1]
+        present_key_value = self_attention_outputs[-1]
+
+        if mode=='multimodal':
+            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                output_attentions=output_attentions,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights                               
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        mode='multimodal',
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
+               
+        for i in range(self.config.num_hidden_layers):
+            layer_module = self.layer[i]
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    mode=mode,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    mode=mode,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = BertConfig
+    base_model_prefix = "bert"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def _init_weights(self, module):
+        """ Initialize the weights """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+    """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+        
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+ 
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device: (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+        """
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if is_decoder:
+                batch_size, seq_length = input_shape
+
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+   
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [
+                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )                     
+
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+    
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=False,
+        mode='multimodal',
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
+            device = input_ids.device
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length = input_shape
+            device = inputs_embeds.device
+        elif encoder_embeds is not None:    
+            input_shape = encoder_embeds.size()[:-1]
+            batch_size, seq_length = input_shape 
+            device = encoder_embeds.device
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+            
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, 
+                                                                                 device, is_decoder)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_hidden_states is not None:
+            if type(encoder_hidden_states) == list:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+            else:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            
+            if type(encoder_attention_mask) == list:
+                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+            elif encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            else:    
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+        
+        if encoder_embeds is None:
+            embedding_output = self.embeddings(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                inputs_embeds=inputs_embeds,
+                past_key_values_length=past_key_values_length,
+            )
+        else:
+            embedding_output = encoder_embeds
+            
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            mode=mode,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
diff --git a/models/vit.py b/models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899
--- /dev/null
+++ b/models/vit.py
@@ -0,0 +1,305 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on timm code base
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+class Mlp(nn.Module):
+    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.attn_gradients = None
+        self.attention_map = None
+        
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+        
+    def get_attn_gradients(self):
+        return self.attn_gradients
+    
+    def save_attention_map(self, attention_map):
+        self.attention_map = attention_map
+        
+    def get_attention_map(self):
+        return self.attention_map
+    
+    def forward(self, x, register_hook=False):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+                
+        if register_hook:
+            self.save_attention_map(attn)
+            attn.register_hook(self.save_attn_gradients)        
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if use_grad_checkpointing:
+            self.attn = checkpoint_wrapper(self.attn)
+            self.mlp = checkpoint_wrapper(self.mlp)
+
+    def forward(self, x, register_hook=False):
+        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+    
+class VisionTransformer(nn.Module):
+    """ Vision Transformer
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
+        https://arxiv.org/abs/2010.11929
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 
+                 use_grad_checkpointing=False, ckpt_layer=0):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            num_classes (int): number of classes for classification head
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            norm_layer: (nn.Module): normalization layer
+        """
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
+            )
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward(self, x, register_blk=-1):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+  
+        x = x + self.pos_embed[:,:x.size(1),:]
+        x = self.pos_drop(x)
+
+        for i,blk in enumerate(self.blocks):
+            x = blk(x, register_blk==i)
+        x = self.norm(x)
+        
+        return x
+
+    @torch.jit.ignore()
+    def load_pretrained(self, checkpoint_path, prefix=''):
+        _load_weights(self, checkpoint_path, prefix)
+        
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+    """
+    import numpy as np
+
+    def _n2p(w, t=True):
+        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+            w = w.flatten()
+        if t:
+            if w.ndim == 4:
+                w = w.transpose([3, 2, 0, 1])
+            elif w.ndim == 3:
+                w = w.transpose([2, 0, 1])
+            elif w.ndim == 2:
+                w = w.transpose([1, 0])
+        return torch.from_numpy(w)
+
+    w = np.load(checkpoint_path)
+    if not prefix and 'opt/target/embedding/kernel' in w:
+        prefix = 'opt/target/'
+
+    if hasattr(model.patch_embed, 'backbone'):
+        # hybrid
+        backbone = model.patch_embed.backbone
+        stem_only = not hasattr(backbone, 'stem')
+        stem = backbone if stem_only else backbone.stem
+        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+        if not stem_only:
+            for i, stage in enumerate(backbone.stages):
+                for j, block in enumerate(stage.blocks):
+                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+                    for r in range(3):
+                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+                    if block.downsample is not None:
+                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+    else:
+        embed_conv_w = adapt_input_conv(
+            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+    model.patch_embed.proj.weight.copy_(embed_conv_w)
+    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+    if pos_embed_w.shape != model.pos_embed.shape:
+        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
+            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+    model.pos_embed.copy_(pos_embed_w)
+    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+#     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+#         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+#         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+#     if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+#         model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+#         model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+    for i, block in enumerate(model.blocks.children()):
+        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.attn.qkv.weight.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+        block.attn.qkv.bias.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+        for r in range(2):
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+            
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):        
+    # interpolate position embedding
+    embedding_size = pos_embed_checkpoint.shape[-1]
+    num_patches = visual_encoder.patch_embed.num_patches
+    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+    # height (== width) for the checkpoint position embedding
+    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+    # height (== width) for the new position embedding
+    new_size = int(num_patches ** 0.5)
+
+    if orig_size!=new_size:
+        # class_token and dist_token are kept unchanged
+        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+        # only the position tokens are interpolated
+        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+        pos_tokens = torch.nn.functional.interpolate(
+            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+        print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
+        
+        return new_pos_embed    
+    else:
+        return pos_embed_checkpoint
\ No newline at end of file
diff --git a/models/xbert.py b/models/xbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b1f7774524bacc0c91a15ec66a8063de3f332a2
--- /dev/null
+++ b/models/xbert.py
@@ -0,0 +1,1916 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model. """
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+    ModelOutput,
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+import transformers
+transformers.logging.set_verbosity_error()
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "BertConfig"
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+
+BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "bert-base-uncased",
+    "bert-large-uncased",
+    "bert-base-cased",
+    "bert-large-cased",
+    "bert-base-multilingual-uncased",
+    "bert-base-multilingual-cased",
+    "bert-base-chinese",
+    "bert-base-german-cased",
+    "bert-large-uncased-whole-word-masking",
+    "bert-large-cased-whole-word-masking",
+    "bert-large-uncased-whole-word-masking-finetuned-squad",
+    "bert-large-cased-whole-word-masking-finetuned-squad",
+    "bert-base-cased-finetuned-mrpc",
+    "bert-base-german-dbmdz-cased",
+    "bert-base-german-dbmdz-uncased",
+    "cl-tohoku/bert-base-japanese",
+    "cl-tohoku/bert-base-japanese-whole-word-masking",
+    "cl-tohoku/bert-base-japanese-char",
+    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
+    "TurkuNLP/bert-base-finnish-cased-v1",
+    "TurkuNLP/bert-base-finnish-uncased-v1",
+    "wietsedv/bert-base-dutch-cased",
+    # See all BERT models at https://huggingface.co/models?filter=bert
+]
+
+
+def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
+    """Load tf checkpoints in a pytorch model."""
+    try:
+        import re
+
+        import numpy as np
+        import tensorflow as tf
+    except ImportError:
+        logger.error(
+            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+            "https://www.tensorflow.org/install/ for installation instructions."
+        )
+        raise
+    tf_path = os.path.abspath(tf_checkpoint_path)
+    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    names = []
+    arrays = []
+    for name, shape in init_vars:
+        logger.info("Loading TF weight {} with shape {}".format(name, shape))
+        array = tf.train.load_variable(tf_path, name)
+        names.append(name)
+        arrays.append(array)
+
+    for name, array in zip(names, arrays):
+        name = name.split("/")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+        # which are not required for using pretrained model
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
+            logger.info("Skipping {}".format("/".join(name)))
+            continue
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+                scope_names = re.split(r"_(\d+)", m_name)
+            else:
+                scope_names = [m_name]
+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "output_weights":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "squad":
+                pointer = getattr(pointer, "classifier")
+            else:
+                try:
+                    pointer = getattr(pointer, scope_names[0])
+                except AttributeError:
+                    logger.info("Skipping {}".format("/".join(name)))
+                    continue
+            if len(scope_names) >= 2:
+                num = int(scope_names[1])
+                pointer = pointer[num]
+        if m_name[-11:] == "_embeddings":
+            pointer = getattr(pointer, "weight")
+        elif m_name == "kernel":
+            array = np.transpose(array)
+        try:
+            assert (
+                pointer.shape == array.shape
+            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+        except AssertionError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        logger.info("Initialize PyTorch weight {}".format(name))
+        pointer.data = torch.from_numpy(array)
+    return model
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        
+        self.config = config
+
+    def forward(
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)  
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config, is_cross_attention):
+        super().__init__()
+        self.config = config
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+            )
+        
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        if is_cross_attention:
+            self.key = nn.Linear(config.encoder_width, self.all_head_size)
+            self.value = nn.Linear(config.encoder_width, self.all_head_size)
+        else:
+            self.key = nn.Linear(config.hidden_size, self.all_head_size)
+            self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+        self.save_attention = False   
+            
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+        
+    def get_attn_gradients(self):
+        return self.attn_gradients
+    
+    def save_attention_map(self, attention_map):
+        self.attention_map = attention_map
+        
+    def get_attention_map(self):
+        return self.attention_map
+    
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+
+        if is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+        
+        if is_cross_attention and self.save_attention:
+            self.save_attention_map(attention_probs)
+            attention_probs.register_hook(self.save_attn_gradients)         
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs_dropped = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs_dropped = attention_probs_dropped * head_mask
+
+        context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config, is_cross_attention=False):
+        super().__init__()
+        self.self = BertSelfAttention(config, is_cross_attention)
+        self.output = BertSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config, layer_num):
+        super().__init__()
+        self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)
+
+        self.has_cross_attention = (layer_num >= config.fusion_layer)
+        if self.has_cross_attention:           
+            self.layer_num = layer_num                
+            self.crossattention = BertAttention(config, is_cross_attention=True)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+
+        outputs = self_attention_outputs[1:-1]
+        present_key_value = self_attention_outputs[-1]
+
+        if self.has_cross_attention:
+            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+            
+            if type(encoder_hidden_states) == list:
+                cross_attention_outputs = self.crossattention(
+                    attention_output,
+                    attention_mask,
+                    head_mask,
+                    encoder_hidden_states[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)],
+                    encoder_attention_mask[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)],
+                    output_attentions=output_attentions,
+                )    
+                attention_output = cross_attention_outputs[0]
+                outputs = outputs + cross_attention_outputs[1:-1]
+         
+            else:
+                cross_attention_outputs = self.crossattention(
+                    attention_output,
+                    attention_mask,
+                    head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions=output_attentions,
+                )
+                attention_output = cross_attention_outputs[0]
+                outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights                               
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        mode='multi_modal',
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
+        
+                
+        if mode=='text': 
+            start_layer = 0
+            output_layer = self.config.fusion_layer
+            
+        elif mode=='fusion':
+            start_layer = self.config.fusion_layer
+            output_layer = self.config.num_hidden_layers
+            
+        elif mode=='multi_modal':
+            start_layer = 0
+            output_layer = self.config.num_hidden_layers        
+        
+        for i in range(start_layer, output_layer):
+            layer_module = self.layer[i]
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertOnlyNSPHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, pooled_output):
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return seq_relationship_score
+
+
+class BertPreTrainingHeads(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, sequence_output, pooled_output):
+        prediction_scores = self.predictions(sequence_output)
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return prediction_scores, seq_relationship_score
+
+
+class BertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = BertConfig
+    load_tf_weights = load_tf_weights_in_bert
+    base_model_prefix = "bert"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def _init_weights(self, module):
+        """ Initialize the weights """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+@dataclass
+class BertForPreTrainingOutput(ModelOutput):
+    """
+    Output type of :class:`~transformers.BertForPreTraining`.
+    Args:
+        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
+            Total loss as the sum of the masked language modeling loss and the next sequence prediction
+            (classification) loss.
+        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+            before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    prediction_logits: torch.FloatTensor = None
+    seq_relationship_logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+BERT_START_DOCSTRING = r"""
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
+    pruning heads etc.)
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
+    general usage and behavior.
+    Parameters:
+        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+            weights.
+"""
+
+BERT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+            Indices of input sequence tokens in the vocabulary.
+            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+            details.
+            `What are input IDs? <../glossary.html#input-ids>`__
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+            1]``:
+            - 0 corresponds to a `sentence A` token,
+            - 1 corresponds to a `sentence B` token.
+            `What are token type IDs? <../glossary.html#token-type-ids>`_
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+            config.max_position_embeddings - 1]``.
+            `What are position IDs? <../glossary.html#position-ids>`_
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+            vectors than the model's internal embedding lookup matrix.
+        output_attentions (:obj:`bool`, `optional`):
+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+            tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`):
+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+            more detail.
+        return_dict (:obj:`bool`, `optional`):
+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+    BERT_START_DOCSTRING,
+)
+class BertModel(BertPreTrainedModel):
+    """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+        
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+ 
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="bert-base-uncased",
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    
+    
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device: (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+        """
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if is_decoder:
+                batch_size, seq_length = input_shape
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [
+                            torch.ones(
+                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
+                            ),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )
+
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+    
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=False,
+        mode='multi_modal',
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
+            device = input_ids.device
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length = input_shape
+            device = inputs_embeds.device
+        elif encoder_embeds is not None:    
+            input_shape = encoder_embeds.size()[:-1]
+            batch_size, seq_length = input_shape 
+            device = encoder_embeds.device
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, 
+                                                                                 device, is_decoder)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_hidden_states is not None:
+            if type(encoder_hidden_states) == list:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+            else:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            
+            if type(encoder_attention_mask) == list:
+                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+            elif encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            else:    
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+        
+        if encoder_embeds is None:
+            embedding_output = self.embeddings(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                token_type_ids=token_type_ids,
+                inputs_embeds=inputs_embeds,
+                past_key_values_length=past_key_values_length,
+            )
+        else:
+            embedding_output = encoder_embeds
+            
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            mode=mode,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+    sentence prediction (classification)` head.
+    """,
+    BERT_START_DOCSTRING,
+)
+class BertForPreTraining(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config)
+        self.cls = BertPreTrainingHeads(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        next_sentence_label=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
+            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
+            - 0 indicates sequence B is a continuation of sequence A,
+            - 1 indicates sequence B is a random sequence.
+        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+            Used to hide legacy arguments that have been deprecated.
+        Returns:
+        Example::
+            >>> from transformers import BertTokenizer, BertForPreTraining
+            >>> import torch
+            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+            >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
+            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> prediction_logits = outputs.prediction_logits
+            >>> seq_relationship_logits = outputs.seq_relationship_logits
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output, pooled_output = outputs[:2]
+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+        total_loss = None
+        if labels is not None and next_sentence_label is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+            total_loss = masked_lm_loss + next_sentence_loss
+
+        if not return_dict:
+            output = (prediction_scores, seq_relationship_score) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return BertForPreTrainingOutput(
+            loss=total_loss,
+            prediction_logits=prediction_scores,
+            seq_relationship_logits=seq_relationship_score,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
+)
+class BertLMHeadModel(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=True,
+        reduction='mean',
+        mode='multi_modal',
+        soft_labels=None,
+        alpha=0,
+        return_logits=False,        
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        Returns:
+        Example::
+            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+            >>> import torch
+            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+            >>> config = BertConfig.from_pretrained("bert-base-cased")
+            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> prediction_logits = outputs.logits
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+            mode=mode,
+        )
+        
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+        
+        if return_logits:
+            return prediction_scores[:, :-1, :].contiguous()  
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss(reduction=reduction)
+            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
+            
+        if soft_labels is not None:
+            loss_distill = -torch.sum(F.log_softmax(shifted_prediction_scores, dim=-1)*soft_labels,dim=-1)
+            loss_distill = (loss_distill * (labels!=-100)).sum(1)
+            lm_loss = (1-alpha)*lm_loss + alpha*loss_distill                    
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids, 
+            "attention_mask": attention_mask, 
+            "past_key_values": past,
+            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+            "is_decoder": True,
+        }
+
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
+
+
+@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
+class BertForMaskedLM(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="bert-base-uncased",
+        output_type=MaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=False,
+        mode='multi_modal',
+        soft_labels=None,
+        alpha=0,
+        return_logits=False,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_embeds=encoder_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+            mode=mode,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+        
+        if return_logits:
+            return prediction_scores
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+        
+        if soft_labels is not None:
+            loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=-1)*soft_labels,dim=-1)
+            loss_distill = loss_distill[labels!=-100].mean()
+            masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        effective_batch_size = input_shape[0]
+
+        #  add a dummy token
+        assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
+        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+        dummy_token = torch.full(
+            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+        )
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@add_start_docstrings(
+    """Bert Model with a `next sentence prediction (classification)` head on top. """,
+    BERT_START_DOCSTRING,
+)
+class BertForNextSentencePrediction(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config)
+        self.cls = BertOnlyNSPHead(config)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        **kwargs
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+            (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
+            - 0 indicates sequence B is a continuation of sequence A,
+            - 1 indicates sequence B is a random sequence.
+        Returns:
+        Example::
+            >>> from transformers import BertTokenizer, BertForNextSentencePrediction
+            >>> import torch
+            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+            >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
+            >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+            >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+            >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
+            >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+            >>> logits = outputs.logits
+            >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
+        """
+
+        if "next_sentence_label" in kwargs:
+            warnings.warn(
+                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+                FutureWarning,
+            )
+            labels = kwargs.pop("next_sentence_label")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        seq_relationship_scores = self.cls(pooled_output)
+
+        next_sentence_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
+
+        if not return_dict:
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
+
+        return NextSentencePredictorOutput(
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
+    """,
+    BERT_START_DOCSTRING,
+)
+class BertForSequenceClassification(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.bert = BertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="bert-base-uncased",
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                #  We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+    softmax) e.g. for RocStories/SWAG tasks.
+    """,
+    BERT_START_DOCSTRING,
+)
+class BertForMultipleChoice(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, 1)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="bert-base-uncased",
+        output_type=MultipleChoiceModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
+            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
+            :obj:`input_ids` above)
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    Named-Entity-Recognition (NER) tasks.
+    """,
+    BERT_START_DOCSTRING,
+)
+class BertForTokenClassification(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="bert-base-uncased",
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+            1]``.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                active_loss = attention_mask.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    BERT_START_DOCSTRING,
+)
+class BertForQuestionAnswering(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="bert-base-uncased",
+        output_type=QuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+            sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+            sequence are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d47697c44a93aa9ee6f091e0be56aa73a38d8d77
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,152 @@
+absl-py==1.2.0
+aiohttp==3.8.3
+aiosignal==1.2.0
+anyio==3.6.1
+asttokens==2.0.8
+async-timeout==4.0.2
+attrs==22.1.0
+av==9.2.0
+backcall==0.2.0
+bcrypt==4.0.0
+cachetools==5.2.0
+certifi==2022.9.14
+cffi==1.15.1
+charset-normalizer==2.1.1
+click==8.1.3
+cloudpickle==2.2.0
+configobj==5.0.6
+contourpy==1.0.5
+cryptography==38.0.1
+cycler==0.11.0
+cytoolz==0.12.0
+debugpy==1.6.3
+decorator==5.1.1
+decord==0.6.0
+easydict==1.10
+einops==0.4.1
+entrypoints==0.4
+executing==1.1.0
+fairscale==0.4.12
+fastapi==0.85.0
+ffmpy==0.3.0
+filelock==3.8.0
+fonttools==4.37.3
+frozenlist==1.3.1
+fsspec==2022.8.2
+ftfy==6.1.1
+google-auth==2.12.0
+google-auth-oauthlib==0.4.6
+gradio==3.4.0
+grpcio==1.49.1
+h11==0.12.0
+httpcore==0.15.0
+httpx==0.23.0
+huggingface-hub==0.9.1
+idna==3.4
+imageio==2.22.1
+importlib-metadata==5.0.0
+inflect==6.0.0
+ipdb==0.13.9
+ipykernel==6.16.0
+ipython==8.5.0
+jedi==0.18.1
+Jinja2==3.1.2
+joblib==1.2.0
+jupyter-core==4.11.1
+jupyter_client==7.3.5
+kiwisolver==1.4.4
+linkify-it-py==1.0.3
+lmdb==1.3.0
+lz4==4.0.2
+Markdown==3.4.1
+markdown-it-py==2.1.0
+MarkupSafe==2.1.1
+matplotlib==3.6.0
+matplotlib-inline==0.1.6
+mdit-py-plugins==0.3.1
+mdurl==0.1.2
+msgpack==1.0.4
+msgpack-numpy==0.4.8
+multidict==6.0.2
+nest-asyncio==1.5.6
+networkx==2.8.7
+nltk==3.7
+numpy==1.23.3
+oauthlib==3.2.1
+opencv-python==4.6.0.66
+orjson==3.8.0
+packaging==21.3
+pandas==1.5.0
+paramiko==2.11.0
+parso==0.8.3
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==9.2.0
+Pillow-SIMD==9.0.0.post1
+prettytable==3.4.1
+prompt-toolkit==3.0.31
+protobuf==3.19.6
+psutil==5.9.2
+ptyprocess==0.7.0
+pure-eval==0.2.2
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycocotools==2.0.5
+pycparser==2.21
+pycryptodome==3.15.0
+pydantic==1.10.2
+pydub==0.25.1
+Pygments==2.13.0
+pymongo==4.2.0
+PyNaCl==1.5.0
+pyparsing==3.0.9
+python-dateutil==2.8.2
+python-multipart==0.0.5
+pytz==2022.4
+PyWavelets==1.4.1
+PyYAML==6.0
+pyzmq==24.0.1
+regex==2022.9.13
+requests==2.28.1
+requests-oauthlib==1.3.1
+rfc3986==1.5.0
+rsa==4.9
+ruamel.yaml==0.17.21
+ruamel.yaml.base==0.3.0
+ruamel.yaml.clib==0.2.6
+ruamel.yaml.cmd==0.6.3
+ruamel.yaml.convert==0.3.2
+sacremoses==0.0.53
+scikit-image==0.19.3
+scipy==1.9.1
+Shapely==1.8.4
+six==1.16.0
+sniffio==1.3.0
+stack-data==0.5.1
+starlette==0.20.4
+tensorboard==2.10.1
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+tensorboardX==2.5.1
+tifffile==2022.8.12
+timm==0.6.7
+tokenizers==0.10.3
+toml==0.10.2
+toolz==0.12.0
+torch==1.10.0+cu113
+torchvision==0.11.0+cu113
+tornado==6.2
+tqdm==4.64.1
+traitlets==5.4.0
+transformers==4.11.3
+typing_extensions==4.3.0
+uc-micro-py==1.0.1
+ujson==5.5.0
+urllib3==1.26.12
+uvicorn==0.18.3
+wcwidth==0.2.5
+websockets==10.3
+Werkzeug==2.2.2
+yacs==0.1.8
+yarl==1.8.1
+zipp==3.9.0
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..837c2cd15f4624f630540ef6993dcb9123adb39b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,69 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#!/usr/bin/env python
+
+import glob
+import os
+
+import torch
+from setuptools import find_packages
+from setuptools import setup
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+requirements = ["torch", "torchvision"]
+
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "maskrcnn_benchmark", "csrc")
+
+    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+
+    include_dirs = [extensions_dir]
+
+    ext_modules = [
+        extension(
+            "maskrcnn_benchmark._C",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+
+    return ext_modules
+
+
+setup(
+    name="maskrcnn_benchmark",
+    version="0.1",
+    author="fmassa",
+    url="https://github.com/facebookresearch/maskrcnn-benchmark",
+    description="object detection in pytorch",
+    packages=find_packages(exclude=("configs", "tests",)),
+    # install_requires=requirements,
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/vqa.py b/vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..f90b1e5469705a89755fb2bebe93ea966f36dcea
--- /dev/null
+++ b/vqa.py
@@ -0,0 +1,127 @@
+import sys
+from PIL import Image
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from models.blip_vqa import blip_vqa
+import cv2
+import numpy as np
+import matplotlib.image as mpimg
+
+from skimage import transform as skimage_transform
+from scipy.ndimage import filters
+from matplotlib import pyplot as plt
+
+
+import torch
+from torch import nn
+from torchvision import transforms
+
+import json
+import traceback
+
+class VQA:
+    def __init__(self, model_path, image_size=480):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
+        self.block_num = 9
+        self.model.eval()
+        self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.save_attention = True
+
+        self.model = self.model.to(self.device)
+    def getAttMap(self, img, attMap, blur = True, overlap = True):
+        attMap -= attMap.min()
+        if attMap.max() > 0:
+            attMap /= attMap.max()
+        attMap = skimage_transform.resize(attMap, (img.shape[:2]), order = 3, mode = 'constant')
+        if blur:
+            attMap = filters.gaussian_filter(attMap, 0.02*max(img.shape[:2]))
+            attMap -= attMap.min()
+            attMap /= attMap.max()
+        cmap = plt.get_cmap('jet')
+        attMapV = cmap(attMap)
+        attMapV = np.delete(attMapV, 3, 2)
+        if overlap:
+            attMap = 1*(1-attMap**0.7).reshape(attMap.shape + (1,))*img + (attMap**0.7).reshape(attMap.shape+(1,)) * attMapV
+        return attMap
+
+    def gradcam(self, text_input, image_path, image):
+        mask = text_input.attention_mask.view(text_input.attention_mask.size(0),1,-1,1,1)
+        grads = self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.get_attn_gradients()
+        cams = self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.get_attention_map()
+        cams = cams[:, :, :, 1:].reshape(image.size(0), 12, -1, 30, 30) * mask
+        grads = grads[:, :, :, 1:].clamp(0).reshape(image.size(0), 12, -1, 30, 30) * mask
+        gradcam = cams * grads
+        gradcam = gradcam[0].mean(0).cpu().detach()
+
+        num_image = len(text_input.input_ids[0])
+        num_image -= 1
+        fig, ax = plt.subplots(num_image, 1, figsize=(15,15*num_image))
+
+        rgb_image = cv2.imread(image_path)[:, :, ::-1]
+        rgb_image = np.float32(rgb_image) / 255
+        ax[0].imshow(rgb_image)
+        ax[0].set_yticks([])
+        ax[0].set_xticks([])
+        ax[0].set_xlabel("Image")
+
+        for i,token_id in enumerate(text_input.input_ids[0][1:-1]):
+            word = self.model.tokenizer.decode([token_id])
+            gradcam_image = self.getAttMap(rgb_image, gradcam[i+1])
+            ax[i+1].imshow(gradcam_image)
+            ax[i+1].set_yticks([])
+            ax[i+1].set_xticks([])
+            ax[i+1].set_xlabel(word)
+        
+        plt.show()
+
+
+    def load_demo_image(self, image_size, img_path, device):
+        raw_image = Image.open(img_path).convert('RGB')   
+        w,h = raw_image.size
+        transform = transforms.Compose([
+            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+            ]) 
+        image = transform(raw_image).unsqueeze(0).to(device)   
+        return raw_image, image
+
+    def vqa(self, img_path, question):
+        raw_image, image = self.load_demo_image(image_size=480, img_path=img_path, device=self.device)        
+        answer, vl_output, que = self.model(image, question, mode='gradcam', inference='generate')
+        loss = vl_output[:,1].sum()
+        self.model.zero_grad()
+        loss.backward()
+
+        with torch.no_grad():
+            self.gradcam(que, img_path, image)
+        
+        return answer[0]
+
+    def vqa_demo(self, image, question):
+        image_size = 480
+        transform = transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+            ]) 
+        image = transform(image).unsqueeze(0).to(self.device)
+        answer = self.model(image, question, mode='inference', inference='generate')
+        
+        return answer[0]
+
+
+if __name__=="__main__":
+    if not len(sys.argv) == 3:
+        print('Format: python3 vqa.py <path_to_img> <question>')
+        print('Sample: python3 vqa.py sample.jpg "What is the color of the horse?"')
+        
+    else:
+        model_path = 'checkpoints/model_base_vqa_capfilt_large.pth'
+        vqa_object = VQA(model_path=model_path)
+        img_path = sys.argv[1]
+        question = sys.argv[2]
+        answer = vqa_object.vqa(img_path, question)
+        print('Question: {} | Answer: {}'.format(question, answer))
\ No newline at end of file