diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..aa7b75fdd0009905fac2cb3f8cb377472fae7984 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+unimernet/processors/formula_processor_helper/frost/frost1.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index aa62729d994fe92209f41eebd524c8a08d46b06c..b0c860c3307c1b9d3f273e1243bd9c7b6ee26201 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,11 @@
----
-title: UniMERNet Demo
-emoji: 🌖
-colorFrom: purple
-colorTo: blue
-sdk: gradio
-sdk_version: 4.43.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: UniMERNet
+emoji: 👁
+colorFrom: purple
+colorTo: blue
+sdk: gradio
+sdk_version: 4.42.0
+app_file: app.py
+pinned: false
+license: apache-2.0
+---
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..350fd6ec39d4d3597012b72a0269fa41358d7c32
--- /dev/null
+++ b/app.py
@@ -0,0 +1,106 @@
+import os
+os.system('pip install -U transformers==4.44.2')
+import sys
+import shutil
+import torch
+import argparse
+import gradio as gr
+import numpy as np
+from PIL import Image
+from huggingface_hub import snapshot_download
+import spaces
+
+# == download weights ==
+tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
+small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
+base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
+os.system("ls -l models/unimernet_tiny")
+# os.system(f"sed -i 's/MODEL_DIR/{tiny_model_dir}/g' cfg_tiny.yaml")
+# os.system(f"sed -i 's/MODEL_DIR/{small_model_dir}/g' cfg_small.yaml")
+# os.system(f"sed -i 's/MODEL_DIR/{base_model_dir}/g' cfg_base.yaml")
+# root_path = os.path.abspath(os.getcwd())
+# os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
+# shutil.move(tiny_model_dir, os.path.join(root_path, "models", "unimernet_tiny"))
+# shutil.move(small_model_dir, os.path.join(root_path, "models", "unimernet_small"))
+# shutil.move(base_model_dir, os.path.join(root_path, "models", "unimernet_base"))
+# == download weights ==
+
+sys.path.insert(0, os.path.join(os.getcwd(), ".."))
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+from unimernet.processors import load_processor
+
+
+def load_model_and_processor(cfg_path):
+ args = argparse.Namespace(cfg_path=cfg_path, options=None)
+ cfg = Config(args)
+ task = tasks.setup_task(cfg)
+ model = task.build_model(cfg)
+ vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
+ return model, vis_processor
+
+@spaces.GPU
+def recognize_image(input_img, model_type):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if model_type == "base":
+ model = model_base.to(device)
+ elif model_type == "small":
+ model = model_small.to(device)
+ else:
+ model = model_tiny.to(device)
+
+ if len(input_img.shape) == 3:
+ input_img = input_img[:, :, ::-1].copy()
+
+ img = Image.fromarray(input_img)
+ image = vis_processor(img).unsqueeze(0).to(device)
+ output = model.generate({"image": image})
+ latex_code = output["pred_str"][0]
+ return latex_code
+
+def gradio_reset():
+ return gr.update(value=None), gr.update(value=None)
+
+
+if __name__ == "__main__":
+ root_path = os.path.abspath(os.getcwd())
+ # == load model ==
+ model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
+ model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
+ model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
+ print("== load all models ==")
+ # == load model ==
+
+ with open("header.html", "r") as file:
+ header = file.read()
+ with gr.Blocks() as demo:
+ gr.HTML(header)
+
+ with gr.Row():
+ with gr.Column():
+ model_type = gr.Radio(
+ choices=["tiny", "small", "base"],
+ value="tiny",
+ label="Model Type",
+ interactive=True,
+ )
+ input_img = gr.Image(label=" ", interactive=True)
+ with gr.Row():
+ clear = gr.Button("Clear")
+ predict = gr.Button(value="Recognize", interactive=True, variant="primary")
+
+ with gr.Accordion("Examples:"):
+ example_root = os.path.join(os.path.dirname(__file__), "examples")
+ gr.Examples(
+ examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
+ _.endswith("png")],
+ inputs=input_img,
+ )
+ with gr.Column():
+ gr.Button(value="Predict Latex:", interactive=False)
+ pred_latex = gr.Textbox(label='Latex', interactive=False)
+
+ clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
+ predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])
+
+ demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
\ No newline at end of file
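
For reference, a minimal sketch of running the same recognition path as `recognize_image` above without the Gradio UI. It assumes `cfg_small.yaml` and the `examples/` images added in this diff are in the working directory, and that the weights have already been fetched (e.g. via the `snapshot_download` calls in `app.py`).

```python
import argparse

import torch
from PIL import Image

from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor

# Build model + processor exactly as load_model_and_processor() does in app.py.
cfg = Config(argparse.Namespace(cfg_path="cfg_small.yaml", options=None))
task = tasks.setup_task(cfg)
model = task.build_model(cfg)
vis_processor = load_processor(
    "formula_image_eval", cfg.config.datasets.formula_rec_eval.vis_processor.eval
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Run a single example image through the same generate() call the demo uses.
img = Image.open("examples/0000004.png").convert("RGB")
image = vis_processor(img).unsqueeze(0).to(device)
output = model.generate({"image": image})
print(output["pred_str"][0])  # predicted LaTeX string
```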
diff --git a/cfg_base.yaml b/cfg_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..473fef8fb2e85156876c963d461375c15aa4cf64
--- /dev/null
+++ b/cfg_base.yaml
@@ -0,0 +1,46 @@
+model:
+ arch: unimernet
+ model_type: unimernet
+ model_config:
+ model_name: ./models/unimernet_base
+ max_seq_len: 1536
+
+ load_pretrained: True
+ pretrained: './models/unimernet_base/unimernet_base.pth'
+ tokenizer_config:
+ path: ./models/unimernet_base
+
+datasets:
+ formula_rec_eval:
+ vis_processor:
+ eval:
+ name: "formula_image_eval"
+ image_size:
+ - 192
+ - 672
+
+run:
+ runner: runner_iter
+ task: unimernet_train
+
+ batch_size_train: 64
+ batch_size_eval: 64
+ num_workers: 1
+
+ iters_per_inner_epoch: 2000
+ max_iters: 60000
+
+ seed: 42
+ output_dir: "../output/demo"
+
+ evaluate: True
+ test_splits: [ "eval" ]
+
+ device: "cuda"
+ world_size: 1
+ dist_url: "env://"
+ distributed: True
+ distributed_type: ddp # or fsdp when train llm
+
+ generate_cfg:
+ temperature: 0.0
\ No newline at end of file
diff --git a/cfg_small.yaml b/cfg_small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb88dd87a21b728c726cd8576afcae5549e26ee9
--- /dev/null
+++ b/cfg_small.yaml
@@ -0,0 +1,46 @@
+model:
+ arch: unimernet
+ model_type: unimernet
+ model_config:
+ model_name: ./models/unimernet_small
+ max_seq_len: 1536
+
+ load_pretrained: True
+ pretrained: './models/unimernet_small/unimernet_small.pth'
+ tokenizer_config:
+ path: ./models/unimernet_small
+
+datasets:
+ formula_rec_eval:
+ vis_processor:
+ eval:
+ name: "formula_image_eval"
+ image_size:
+ - 192
+ - 672
+
+run:
+ runner: runner_iter
+ task: unimernet_train
+
+ batch_size_train: 64
+ batch_size_eval: 64
+ num_workers: 1
+
+ iters_per_inner_epoch: 2000
+ max_iters: 60000
+
+ seed: 42
+ output_dir: "../output/demo"
+
+ evaluate: True
+ test_splits: [ "eval" ]
+
+ device: "cuda"
+ world_size: 1
+ dist_url: "env://"
+ distributed: True
+ distributed_type: ddp # or fsdp when train llm
+
+ generate_cfg:
+ temperature: 0.0
\ No newline at end of file
diff --git a/cfg_tiny.yaml b/cfg_tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f16ebfda9ac094e551ede4388ffe805ece56be0f
--- /dev/null
+++ b/cfg_tiny.yaml
@@ -0,0 +1,46 @@
+model:
+ arch: unimernet
+ model_type: unimernet
+ model_config:
+ model_name: ./models/unimernet_tiny
+ max_seq_len: 1536
+
+ load_pretrained: True
+ pretrained: './models/unimernet_tiny/unimernet_tiny.pth'
+ tokenizer_config:
+ path: ./models/unimernet_tiny
+
+datasets:
+ formula_rec_eval:
+ vis_processor:
+ eval:
+ name: "formula_image_eval"
+ image_size:
+ - 192
+ - 672
+
+run:
+ runner: runner_iter
+ task: unimernet_train
+
+ batch_size_train: 64
+ batch_size_eval: 64
+ num_workers: 1
+
+ iters_per_inner_epoch: 2000
+ max_iters: 60000
+
+ seed: 42
+ output_dir: "../output/demo"
+
+ evaluate: True
+ test_splits: [ "eval" ]
+
+ device: "cuda"
+ world_size: 1
+ dist_url: "env://"
+ distributed: True
+ distributed_type: ddp # or fsdp when train llm
+
+ generate_cfg:
+ temperature: 0.0
\ No newline at end of file
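
The three `cfg_*.yaml` files above are identical except for the model, checkpoint, and tokenizer paths. A small sketch of inspecting and overriding one of them with OmegaConf (the library the `Config` class later in this diff is built on); the override value shown is illustrative only.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("cfg_tiny.yaml")

print(cfg.model.model_config.model_name)                                   # ./models/unimernet_tiny
print(list(cfg.datasets.formula_rec_eval.vis_processor.eval.image_size))   # [192, 672]

# Point the config at a different checkpoint in memory, without editing the file.
cfg.model.pretrained = "./models/unimernet_tiny/other_checkpoint.pth"  # hypothetical path
```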
diff --git a/examples/0000004.png b/examples/0000004.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0006f68e8c1e18e258baaab3263228c461d7f0c
Binary files /dev/null and b/examples/0000004.png differ
diff --git a/examples/0000005.png b/examples/0000005.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ec583b402563a2bf34cf66664c13f2c905f631e
Binary files /dev/null and b/examples/0000005.png differ
diff --git a/examples/0000006.png b/examples/0000006.png
new file mode 100644
index 0000000000000000000000000000000000000000..5acb9e4020fe7b1b5068dc21cc001b8d02f63096
Binary files /dev/null and b/examples/0000006.png differ
diff --git a/examples/0000007.png b/examples/0000007.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc807adb3ada1010f32b5b3c678d178fbe2444b0
Binary files /dev/null and b/examples/0000007.png differ
diff --git a/examples/0000011.png b/examples/0000011.png
new file mode 100644
index 0000000000000000000000000000000000000000..2417a6f37f86b585ce8ffc0527a1268f4ea1d3d6
Binary files /dev/null and b/examples/0000011.png differ
diff --git a/gitattributes b/gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..aa7b75fdd0009905fac2cb3f8cb377472fae7984
--- /dev/null
+++ b/gitattributes
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+unimernet/processors/formula_processor_helper/frost/frost1.png filter=lfs diff=lfs merge=lfs -text
diff --git a/header.html b/header.html
new file mode 100644
index 0000000000000000000000000000000000000000..aa6a5ce08973f4a0958e37375d86bafe11aab411
--- /dev/null
+++ b/header.html
@@ -0,0 +1,109 @@
+        A Universal Network for Real-World Mathematical Expression Recognition.
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c93d5554121f09f4b191ac555c6c45b4e04724ec
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+unimernet==0.1.6
+gradio
+huggingface_hub
\ No newline at end of file
diff --git a/unimernet/__init__.py b/unimernet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a759d516575604fdf99b56de5cdf149e579012
--- /dev/null
+++ b/unimernet/__init__.py
@@ -0,0 +1,31 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import sys
+
+from omegaconf import OmegaConf
+
+from unimernet.common.registry import registry
+
+from unimernet.datasets.builders import *
+from unimernet.models import *
+from unimernet.processors import *
+from unimernet.tasks import *
+
+
+root_dir = os.path.dirname(os.path.abspath(__file__))
+default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
+
+registry.register_path("library_root", root_dir)
+repo_root = os.path.join(root_dir, "..")
+registry.register_path("repo_root", repo_root)
+cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+registry.register_path("cache_root", cache_root)
+
+registry.register("MAX_INT", sys.maxsize)
+registry.register("SPLIT_NAMES", ["train", "val", "test"])
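
A short sketch of reading back the paths and constants registered above; importing the package runs this `__init__` and populates the registry, assuming the rest of the package (e.g. the `configs/default.yaml` referenced above) is in place.

```python
import unimernet  # noqa: F401  (importing the package runs the registrations above)
from unimernet.common.registry import registry

print(registry.get_path("library_root"))  # .../unimernet
print(registry.get_path("cache_root"))    # cache dir resolved from configs/default.yaml
print(registry.get("MAX_INT"))            # sys.maxsize
print(registry.get("SPLIT_NAMES"))        # ['train', 'val', 'test']
```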
diff --git a/unimernet/__pycache__/__init__.cpython-310.pyc b/unimernet/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edea971fd906b1fc9895653f06b14d6b98d4f387
Binary files /dev/null and b/unimernet/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/common/__init__.py b/unimernet/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/common/__pycache__/__init__.cpython-310.pyc b/unimernet/common/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14c726cbd5b8ce37830e6d27ec8d481518520fe2
Binary files /dev/null and b/unimernet/common/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/config.cpython-310.pyc b/unimernet/common/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b8085e1fb2360a3236efb0f8d4f080bd5fb8f8f
Binary files /dev/null and b/unimernet/common/__pycache__/config.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/dist_utils.cpython-310.pyc b/unimernet/common/__pycache__/dist_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b714750191d75d4b2575059e05940017d2c395e1
Binary files /dev/null and b/unimernet/common/__pycache__/dist_utils.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/logger.cpython-310.pyc b/unimernet/common/__pycache__/logger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0964be7a43f1cd39e88bbdcf94aa45b88f831e27
Binary files /dev/null and b/unimernet/common/__pycache__/logger.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/registry.cpython-310.pyc b/unimernet/common/__pycache__/registry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9101b0789cd5ffc61eab6ec231cc0661b0dc6e41
Binary files /dev/null and b/unimernet/common/__pycache__/registry.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/utils.cpython-310.pyc b/unimernet/common/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94155df43180ba04fbf358536c93dc7bd69ebcab
Binary files /dev/null and b/unimernet/common/__pycache__/utils.cpython-310.pyc differ
diff --git a/unimernet/common/config.py b/unimernet/common/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bfcff2e4c07557832501594a6a4ed0871c166fa
--- /dev/null
+++ b/unimernet/common/config.py
@@ -0,0 +1,468 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import json
+from typing import Dict
+
+from omegaconf import OmegaConf
+from unimernet.common.registry import registry
+
+
+class Config:
+ def __init__(self, args):
+ self.config = {}
+
+ self.args = args
+
+ # Register the config and configuration for setup
+ registry.register("configuration", self)
+
+ user_config = self._build_opt_list(self.args.options)
+
+ config = OmegaConf.load(self.args.cfg_path)
+
+ runner_config = self.build_runner_config(config)
+ model_config = self.build_model_config(config, **user_config)
+ dataset_config = self.build_dataset_config(config)
+
+ # Validate the user-provided runner configuration
+ # model and dataset configuration are supposed to be validated by the respective classes
+ # [TODO] validate the model/dataset configuration
+ # self._validate_runner_config(runner_config)
+
+ # Override the default configuration with user options.
+ self.config = OmegaConf.merge(
+ runner_config, model_config, dataset_config, user_config
+ )
+
+ def _validate_runner_config(self, runner_config):
+ """
+ This method validates the configuration, such that
+ 1) all the user specified options are valid;
+ 2) no type mismatches between the user specified options and the config.
+ """
+ runner_config_validator = create_runner_config_validator()
+ runner_config_validator.validate(runner_config)
+
+ def _build_opt_list(self, opts):
+ opts_dot_list = self._convert_to_dot_list(opts)
+ return OmegaConf.from_dotlist(opts_dot_list)
+
+ @staticmethod
+ def build_model_config(config, **kwargs):
+ model = config.get("model", None)
+ assert model is not None, "Missing model configuration file."
+
+ model_cls = registry.get_model_class(model.arch)
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
+
+ model_type = kwargs.get("model.model_type", None)
+ if not model_type:
+ model_type = model.get("model_type", None)
+ # else use the model type selected by user.
+
+ assert model_type is not None, "Missing model_type."
+
+ model_config_path = model_cls.default_config_path(model_type=model_type)
+
+ model_config = OmegaConf.create()
+        # hierarchy override, customized config > default config
+ model_config = OmegaConf.merge(
+ model_config,
+ OmegaConf.load(model_config_path),
+ {"model": config["model"]},
+ )
+
+ return model_config
+
+ @staticmethod
+ def build_runner_config(config):
+ return {"run": config.run}
+
+ @staticmethod
+ def build_dataset_config(config):
+ datasets = config.get("datasets", None)
+ if datasets is None:
+ raise KeyError(
+ "Expecting 'datasets' as the root key for dataset configuration."
+ )
+
+ dataset_config = OmegaConf.create()
+
+ for dataset_name in datasets:
+ builder_cls = registry.get_builder_class(dataset_name)
+
+ dataset_config_type = datasets[dataset_name].get("type", "default")
+ dataset_config_path = builder_cls.default_config_path(
+ type=dataset_config_type
+ )
+
+            # hierarchy override, customized config > default config
+ dataset_config = OmegaConf.merge(
+ dataset_config,
+ OmegaConf.load(dataset_config_path),
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
+ )
+
+ return dataset_config
+
+ def _convert_to_dot_list(self, opts):
+ if opts is None:
+ opts = []
+
+ if len(opts) == 0:
+ return opts
+
+ has_equal = opts[0].find("=") != -1
+
+ if has_equal:
+ return opts
+
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
+
+ def get_config(self):
+ return self.config
+
+ @property
+ def run_cfg(self):
+ return self.config.run
+
+ @property
+ def datasets_cfg(self):
+ return self.config.datasets
+
+ @property
+ def model_cfg(self):
+ return self.config.model
+
+ def pretty_print(self):
+ logging.info("\n===== Running Parameters =====")
+ logging.info(self._convert_node_to_json(self.config.run))
+
+ logging.info("\n====== Dataset Attributes ======")
+ datasets = self.config.datasets
+
+ for dataset in datasets:
+ if dataset in self.config.datasets:
+ logging.info(f"\n======== {dataset} =======")
+ dataset_config = self.config.datasets[dataset]
+ logging.info(self._convert_node_to_json(dataset_config))
+ else:
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
+
+ logging.info(f"\n====== Model Attributes ======")
+ logging.info(self._convert_node_to_json(self.config.model))
+
+ def _convert_node_to_json(self, node):
+ container = OmegaConf.to_container(node, resolve=True)
+ return json.dumps(container, indent=4, sort_keys=True)
+
+ def to_dict(self):
+ return OmegaConf.to_container(self.config)
+
+
+def node_to_dict(node):
+ return OmegaConf.to_container(node)
+
+
+class ConfigValidator:
+ """
+ This is a preliminary implementation to centralize and validate the configuration.
+ May be altered in the future.
+
+ A helper class to validate configurations from yaml file.
+
+ This serves the following purposes:
+    1. Ensure all the options in the yaml are defined; raise an error if not.
+    2. When type mismatches are found, the validator will raise an error.
+    3. A central place to store and display helpful messages for supported configurations.
+
+ """
+
+ class _Argument:
+ def __init__(self, name, choices=None, type=None, help=None):
+ self.name = name
+ self.val = None
+ self.choices = choices
+ self.type = type
+ self.help = help
+
+ def __str__(self):
+ s = f"{self.name}={self.val}"
+ if self.type is not None:
+ s += f", ({self.type})"
+ if self.choices is not None:
+ s += f", choices: {self.choices}"
+ if self.help is not None:
+ s += f", ({self.help})"
+ return s
+
+ def __init__(self, description):
+ self.description = description
+
+ self.arguments = dict()
+
+ self.parsed_args = None
+
+ def __getitem__(self, key):
+ assert self.parsed_args is not None, "No arguments parsed yet."
+
+ return self.parsed_args[key]
+
+ def __str__(self) -> str:
+ return self.format_help()
+
+ def add_argument(self, *args, **kwargs):
+ """
+ Assume the first argument is the name of the argument.
+ """
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
+
+ def validate(self, config=None):
+ """
+ Convert yaml config (dict-like) to list, required by argparse.
+ """
+ for k, v in config.items():
+ assert (
+ k in self.arguments
+            ), f"""{k} is not a valid argument. Supported arguments are {self.format_arguments()}."""
+
+ if self.arguments[k].type is not None:
+ try:
+ self.arguments[k].val = self.arguments[k].type(v)
+ except ValueError:
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
+
+ if self.arguments[k].choices is not None:
+ assert (
+ v in self.arguments[k].choices
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
+
+ return config
+
+ def format_arguments(self):
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
+
+ def format_help(self):
+ # description + key-value pair string for each argument
+ help_msg = str(self.description)
+ return help_msg + ", available arguments: " + self.format_arguments()
+
+ def print_help(self):
+ # display help message
+ print(self.format_help())
+
+
+def create_runner_config_validator():
+ validator = ConfigValidator(description="Runner configurations")
+
+ validator.add_argument(
+ "runner",
+ type=str,
+ choices=["runner_base", "runner_iter"],
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
+ runner runs based on iters. Default: runner_base""",
+ )
+    # add arguments for training dataset ratios
+ validator.add_argument(
+ "train_dataset_ratios",
+ type=Dict[str, float],
+ help="""Ratios of training dataset. This is used in iteration-based runner.
+    Not supported for the epoch-based runner because defining an epoch there is tricky.
+ Default: None""",
+ )
+ validator.add_argument(
+ "max_iters",
+ type=float,
+ help="Maximum number of iterations to run.",
+ )
+ validator.add_argument(
+ "max_epoch",
+ type=int,
+ help="Maximum number of epochs to run.",
+ )
+ # add arguments for iters_per_inner_epoch
+ validator.add_argument(
+ "iters_per_inner_epoch",
+ type=float,
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
+ )
+ lr_scheds_choices = registry.list_lr_schedulers()
+ validator.add_argument(
+ "lr_sched",
+ type=str,
+ choices=lr_scheds_choices,
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
+ )
+ task_choices = registry.list_tasks()
+ validator.add_argument(
+ "task",
+ type=str,
+ choices=task_choices,
+ help="Task to use, from {}".format(task_choices),
+ )
+ # add arguments for init_lr
+ validator.add_argument(
+ "init_lr",
+ type=float,
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
+ )
+ # add arguments for min_lr
+ validator.add_argument(
+ "min_lr",
+ type=float,
+ help="Minimum learning rate (after decay).",
+ )
+ # add arguments for warmup_lr
+ validator.add_argument(
+ "warmup_lr",
+ type=float,
+ help="Starting learning rate for warmup.",
+ )
+ # add arguments for learning rate decay rate
+ validator.add_argument(
+ "lr_decay_rate",
+ type=float,
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
+ )
+ # add arguments for weight decay
+ validator.add_argument(
+ "weight_decay",
+ type=float,
+ help="Weight decay rate.",
+ )
+ # add arguments for training batch size
+ validator.add_argument(
+ "batch_size_train",
+ type=int,
+ help="Training batch size.",
+ )
+ # add arguments for evaluation batch size
+ validator.add_argument(
+ "batch_size_eval",
+ type=int,
+ help="Evaluation batch size, including validation and testing.",
+ )
+ # add arguments for number of workers for data loading
+ validator.add_argument(
+ "num_workers",
+ help="Number of workers for data loading.",
+ )
+ # add arguments for warm up steps
+ validator.add_argument(
+ "warmup_steps",
+ type=int,
+ help="Number of warmup steps. Required if a warmup schedule is used.",
+ )
+ # add arguments for random seed
+ validator.add_argument(
+ "seed",
+ type=int,
+ help="Random seed.",
+ )
+ # add arguments for output directory
+ validator.add_argument(
+ "output_dir",
+ type=str,
+ help="Output directory to save checkpoints and logs.",
+ )
+ # add arguments for whether only use evaluation
+ validator.add_argument(
+ "evaluate",
+ help="Whether to only evaluate the model. If true, training will not be performed.",
+ )
+ # add arguments for splits used for training, e.g. ["train", "val"]
+ validator.add_argument(
+ "train_splits",
+ type=list,
+ help="Splits to use for training.",
+ )
+ # add arguments for splits used for validation, e.g. ["val"]
+ validator.add_argument(
+ "valid_splits",
+ type=list,
+ help="Splits to use for validation. If not provided, will skip the validation.",
+ )
+ # add arguments for splits used for testing, e.g. ["test"]
+ validator.add_argument(
+ "test_splits",
+ type=list,
+ help="Splits to use for testing. If not provided, will skip the testing.",
+ )
+ # add arguments for accumulating gradient for iterations
+ validator.add_argument(
+ "accum_grad_iters",
+ type=int,
+ help="Number of iterations to accumulate gradient for.",
+ )
+
+ # ====== distributed training ======
+ validator.add_argument(
+ "device",
+ type=str,
+ choices=["cpu", "cuda"],
+        help="Device to use. Supports 'cuda' or 'cpu' for now.",
+ )
+ validator.add_argument(
+ "world_size",
+ type=int,
+ help="Number of processes participating in the job.",
+ )
+ validator.add_argument("dist_url", type=str)
+ validator.add_argument("distributed", type=bool)
+ # add arguments to opt using distributed sampler during evaluation or not
+ validator.add_argument(
+ "use_dist_eval_sampler",
+ type=bool,
+ help="Whether to use distributed sampler during evaluation or not.",
+ )
+
+ # ====== task specific ======
+ # generation task specific arguments
+ # add arguments for maximal length of text output
+ validator.add_argument(
+ "max_len",
+ type=int,
+ help="Maximal length of text output.",
+ )
+ # add arguments for minimal length of text output
+ validator.add_argument(
+ "min_len",
+ type=int,
+ help="Minimal length of text output.",
+ )
+    # add arguments for number of beams
+ validator.add_argument(
+ "num_beams",
+ type=int,
+ help="Number of beams used for beam search.",
+ )
+
+ # vqa task specific arguments
+ # add arguments for number of answer candidates
+ validator.add_argument(
+ "num_ans_candidates",
+ type=int,
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
+ )
+ # add arguments for inference method
+ validator.add_argument(
+ "inference_method",
+ type=str,
+ choices=["genearte", "rank"],
+        help="""Inference method to use for question answering. If rank, requires an answer list.""",
+ )
+
+ # ====== model specific ======
+ validator.add_argument(
+ "k_test",
+ type=int,
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
+ )
+
+ return validator
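
A sketch of the `options` dot-list path that `_build_opt_list` / `_convert_to_dot_list` implement: values passed this way are merged last, so they override the YAML. It assumes `cfg_tiny.yaml` from this diff is on disk and that importing `unimernet` registers the model, task, and dataset builders.

```python
import argparse

from unimernet.common.config import Config

cfg = Config(
    argparse.Namespace(
        cfg_path="cfg_tiny.yaml",
        options=["run.batch_size_eval=16", "run.num_workers=2"],  # dot-list overrides
    )
)

print(cfg.run_cfg.batch_size_eval)  # 16, overriding the 64 in the YAML
cfg.pretty_print()                  # logs the run/dataset/model sections (at INFO level)
```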
diff --git a/unimernet/common/dist_utils.py b/unimernet/common/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..296a3c86f29c6e82fa8f1108c7dd9fa7d3e9ce45
--- /dev/null
+++ b/unimernet/common/dist_utils.py
@@ -0,0 +1,137 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import timm.models.hub as timm_hub
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def init_distributed_mode(args):
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ elif "SLURM_PROCID" in os.environ:
+ args.rank = int(os.environ["SLURM_PROCID"])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print("Not using distributed mode")
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}, world {}): {}".format(
+ args.rank, args.world_size, args.dist_url
+ ),
+ flush=True,
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ timeout=datetime.timedelta(
+ days=365
+ ), # allow auto-downloading and de-compressing
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+def get_dist_info():
+ if torch.__version__ < "1.0":
+ initialized = dist._initialized
+ else:
+ initialized = dist.is_initialized()
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else: # non-distributed training
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def main_process(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+ """
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+ """
+
+ def get_cached_file_path():
+ # a hack to sync the file path across processes
+ parts = torch.hub.urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+ return cached_file
+
+ if is_main_process():
+ timm_hub.download_cached_file(url, check_hash, progress)
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ return get_cached_file_path()
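
A sketch of how the helpers above are typically combined: `@main_process` guards side effects so they run on rank 0 only, while `get_rank` / `is_main_process` fall back to single-process values when torch.distributed is not initialized. The `save_report` function and its output path are illustrative.

```python
from unimernet.common.dist_utils import get_rank, is_main_process, main_process


@main_process
def save_report(path):
    # Runs only on rank 0; other ranks return None without touching the file.
    with open(path, "w") as f:
        f.write("evaluation finished\n")


print(get_rank(), is_main_process())  # 0, True when torch.distributed is not initialized
save_report("report.txt")             # hypothetical output file
```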
diff --git a/unimernet/common/gradcam.py b/unimernet/common/gradcam.py
new file mode 100644
index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0
--- /dev/null
+++ b/unimernet/common/gradcam.py
@@ -0,0 +1,24 @@
+import numpy as np
+from matplotlib import pyplot as plt
+from scipy.ndimage import filters
+from skimage import transform as skimage_transform
+
+
+def getAttMap(img, attMap, blur=True, overlap=True):
+ attMap -= attMap.min()
+ if attMap.max() > 0:
+ attMap /= attMap.max()
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
+ if blur:
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
+ attMap -= attMap.min()
+ attMap /= attMap.max()
+ cmap = plt.get_cmap("jet")
+ attMapV = cmap(attMap)
+ attMapV = np.delete(attMapV, 3, 2)
+ if overlap:
+ attMap = (
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+ )
+ return attMap
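
A usage sketch for `getAttMap`: it expects an RGB image scaled to [0, 1] and a 2D attention map, and returns the blended overlay. The array sizes and random values below are placeholders.

```python
import numpy as np

from unimernet.common.gradcam import getAttMap

img = np.random.rand(224, 224, 3)  # float RGB image in [0, 1]
att = np.random.rand(14, 14)       # raw attention map from the model

overlay = getAttMap(img, att, blur=True, overlap=True)
print(overlay.shape)  # (224, 224, 3)
```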
diff --git a/unimernet/common/logger.py b/unimernet/common/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d46267ed367996f17dc5a3df80e8bdb20b76af
--- /dev/null
+++ b/unimernet/common/logger.py
@@ -0,0 +1,195 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import time
+from collections import defaultdict, deque
+
+import torch
+import torch.distributed as dist
+
+from unimernet.common import dist_utils
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not dist_utils.is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ log_msg = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_msg.append("max mem: {memory:.0f}")
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len(iterable)
+ )
+ )
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def setup_logger():
+ logging.basicConfig(
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
+ format="%(asctime)s [%(levelname)s] %(message)s",
+ handlers=[logging.StreamHandler()],
+ )
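
A sketch of the intended use of `MetricLogger` during an evaluation loop; the loss value here is a placeholder.

```python
from unimernet.common.logger import MetricLogger, setup_logger

setup_logger()
metric_logger = MetricLogger(delimiter="  ")

for step in metric_logger.log_every(range(100), print_freq=10, header="Eval:"):
    loss = 0.5  # placeholder for a real per-batch loss
    metric_logger.update(loss=loss)

print("Averaged stats:", metric_logger.global_avg())
```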
diff --git a/unimernet/common/optims.py b/unimernet/common/optims.py
new file mode 100644
index 0000000000000000000000000000000000000000..148b5a2c30520ae3e0e033142300ba90703c6939
--- /dev/null
+++ b/unimernet/common/optims.py
@@ -0,0 +1,120 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import math
+
+from unimernet.common.registry import registry
+
+
+@registry.register_lr_scheduler("linear_warmup_step_lr")
+class LinearWarmupStepLRScheduler:
+ def __init__(
+ self,
+ optimizer,
+ max_epoch,
+ min_lr,
+ init_lr,
+ decay_rate=1,
+ warmup_start_lr=-1,
+ warmup_steps=0,
+ **kwargs
+ ):
+ self.optimizer = optimizer
+
+ self.max_epoch = max_epoch
+ self.min_lr = min_lr
+
+ self.decay_rate = decay_rate
+
+ self.init_lr = init_lr
+ self.warmup_steps = warmup_steps
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+ def step(self, cur_epoch, cur_step):
+ if cur_epoch == 0:
+ warmup_lr_schedule(
+ step=cur_step,
+ optimizer=self.optimizer,
+ max_step=self.warmup_steps,
+ init_lr=self.warmup_start_lr,
+ max_lr=self.init_lr,
+ )
+ else:
+ step_lr_schedule(
+ epoch=cur_epoch,
+ optimizer=self.optimizer,
+ init_lr=self.init_lr,
+ min_lr=self.min_lr,
+ decay_rate=self.decay_rate,
+ )
+
+
+@registry.register_lr_scheduler("linear_warmup_cosine_lr")
+class LinearWarmupCosineLRScheduler:
+ def __init__(
+ self,
+ optimizer,
+ max_epoch,
+ min_lr,
+ init_lr,
+ iters_per_epoch,
+ warmup_steps=0,
+ warmup_start_lr=-1,
+ **kwargs
+ ):
+ self.optimizer = optimizer
+
+ self.max_epoch = max_epoch
+ self.min_lr = min_lr
+
+ self.init_lr = init_lr
+ self.warmup_steps = warmup_steps
+ self.iters_per_epoch = iters_per_epoch
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+ def step(self, cur_epoch, cur_step):
+        # assuming the warmup iters span less than one epoch
+ total_steps = cur_epoch * self.iters_per_epoch + cur_step
+ if total_steps < self.warmup_steps:
+ warmup_lr_schedule(
+ step=cur_step,
+ optimizer=self.optimizer,
+ max_step=self.warmup_steps,
+ init_lr=self.warmup_start_lr,
+ max_lr=self.init_lr,
+ )
+ else:
+ cosine_lr_schedule(
+ epoch=total_steps,
+ optimizer=self.optimizer,
+ max_epoch=self.max_epoch * self.iters_per_epoch,
+ init_lr=self.init_lr,
+ min_lr=self.min_lr,
+ )
+
+
+def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+ """Decay the learning rate"""
+ lr = (init_lr - min_lr) * 0.5 * (
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
+ ) + min_lr
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+
+def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+ """Warmup the learning rate"""
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+
+def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+ """Decay the learning rate"""
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
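
A sketch of wiring one of the schedulers above to a torch optimizer; `step()` is called once per iteration with the current epoch and step indices. The model and hyperparameter values are placeholders.

```python
import torch

from unimernet.common.optims import LinearWarmupCosineLRScheduler

model = torch.nn.Linear(10, 10)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

iters_per_epoch = 100
scheduler = LinearWarmupCosineLRScheduler(
    optimizer,
    max_epoch=10,
    min_lr=1e-6,
    init_lr=1e-4,
    iters_per_epoch=iters_per_epoch,
    warmup_steps=50,
    warmup_start_lr=1e-6,
)

for epoch in range(10):
    for step in range(iters_per_epoch):
        scheduler.step(cur_epoch=epoch, cur_step=step)
        # forward / backward / optimizer.step() would go here
```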
diff --git a/unimernet/common/registry.py b/unimernet/common/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..699c1bc137ea422e3ecde40d3fade83b0bac45f0
--- /dev/null
+++ b/unimernet/common/registry.py
@@ -0,0 +1,329 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+
+class Registry:
+ mapping = {
+ "builder_name_mapping": {},
+ "task_name_mapping": {},
+ "processor_name_mapping": {},
+ "model_name_mapping": {},
+ "lr_scheduler_name_mapping": {},
+ "runner_name_mapping": {},
+ "state": {},
+ "paths": {},
+ }
+
+ @classmethod
+ def register_builder(cls, name):
+ r"""Register a dataset builder to registry with key 'name'
+
+ Args:
+ name: Key with which the builder will be registered.
+
+ Usage:
+
+ from unimernet.common.registry import registry
+ from unimernet.datasets.base_dataset_builder import BaseDatasetBuilder
+ """
+
+ def wrap(builder_cls):
+ from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+
+ assert issubclass(
+ builder_cls, BaseDatasetBuilder
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
+ builder_cls
+ )
+ if name in cls.mapping["builder_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["builder_name_mapping"][name]
+ )
+ )
+ cls.mapping["builder_name_mapping"][name] = builder_cls
+ return builder_cls
+
+ return wrap
+
+ @classmethod
+ def register_task(cls, name):
+ r"""Register a task to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from unimernet.common.registry import registry
+ """
+
+ def wrap(task_cls):
+ from unimernet.tasks.base_task import BaseTask
+
+ assert issubclass(
+ task_cls, BaseTask
+ ), "All tasks must inherit BaseTask class"
+ if name in cls.mapping["task_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["task_name_mapping"][name]
+ )
+ )
+ cls.mapping["task_name_mapping"][name] = task_cls
+ return task_cls
+
+ return wrap
+
+ @classmethod
+ def register_model(cls, name):
+        r"""Register a model to registry with key 'name'
+
+ Args:
+            name: Key with which the model will be registered.
+
+ Usage:
+
+ from unimernet.common.registry import registry
+ """
+
+ def wrap(model_cls):
+ from unimernet.models import BaseModel
+
+ assert issubclass(
+ model_cls, BaseModel
+ ), "All models must inherit BaseModel class"
+ if name in cls.mapping["model_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["model_name_mapping"][name]
+ )
+ )
+ cls.mapping["model_name_mapping"][name] = model_cls
+ return model_cls
+
+ return wrap
+
+ @classmethod
+ def register_processor(cls, name):
+ r"""Register a processor to registry with key 'name'
+
+ Args:
+            name: Key with which the processor will be registered.
+
+ Usage:
+
+ from unimernet.common.registry import registry
+ """
+
+ def wrap(processor_cls):
+ from unimernet.processors import BaseProcessor
+
+ assert issubclass(
+ processor_cls, BaseProcessor
+ ), "All processors must inherit BaseProcessor class"
+ if name in cls.mapping["processor_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["processor_name_mapping"][name]
+ )
+ )
+ cls.mapping["processor_name_mapping"][name] = processor_cls
+ return processor_cls
+
+ return wrap
+
+ @classmethod
+ def register_lr_scheduler(cls, name):
+        r"""Register a lr scheduler to registry with key 'name'
+
+ Args:
+            name: Key with which the lr scheduler will be registered.
+
+ Usage:
+
+ from unimernet.common.registry import registry
+ """
+
+ def wrap(lr_sched_cls):
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
+ )
+ )
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
+ return lr_sched_cls
+
+ return wrap
+
+ @classmethod
+ def register_runner(cls, name):
+        r"""Register a runner to registry with key 'name'
+
+ Args:
+            name: Key with which the runner will be registered.
+
+ Usage:
+
+ from unimernet.common.registry import registry
+ """
+
+ def wrap(runner_cls):
+ if name in cls.mapping["runner_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["runner_name_mapping"][name]
+ )
+ )
+ cls.mapping["runner_name_mapping"][name] = runner_cls
+ return runner_cls
+
+ return wrap
+
+ @classmethod
+ def register_path(cls, name, path):
+ r"""Register a path to registry with key 'name'
+
+ Args:
+ name: Key with which the path will be registered.
+
+ Usage:
+
+ from unimernet.common.registry import registry
+ """
+        assert isinstance(path, str), "All paths must be str."
+ if name in cls.mapping["paths"]:
+ raise KeyError("Name '{}' already registered.".format(name))
+ cls.mapping["paths"][name] = path
+
+ @classmethod
+ def register(cls, name, obj):
+ r"""Register an item to registry with key 'name'
+
+ Args:
+ name: Key with which the item will be registered.
+
+ Usage::
+
+ from unimernet.common.registry import registry
+
+ registry.register("config", {})
+ """
+ path = name.split(".")
+ current = cls.mapping["state"]
+
+ for part in path[:-1]:
+ if part not in current:
+ current[part] = {}
+ current = current[part]
+
+ current[path[-1]] = obj
+
+ # @classmethod
+ # def get_trainer_class(cls, name):
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_builder_class(cls, name):
+ return cls.mapping["builder_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_model_class(cls, name):
+ return cls.mapping["model_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_task_class(cls, name):
+ return cls.mapping["task_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_processor_class(cls, name):
+ return cls.mapping["processor_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_lr_scheduler_class(cls, name):
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_runner_class(cls, name):
+ return cls.mapping["runner_name_mapping"].get(name, None)
+
+ @classmethod
+ def list_runners(cls):
+ return sorted(cls.mapping["runner_name_mapping"].keys())
+
+ @classmethod
+ def list_models(cls):
+ return sorted(cls.mapping["model_name_mapping"].keys())
+
+ @classmethod
+ def list_tasks(cls):
+ return sorted(cls.mapping["task_name_mapping"].keys())
+
+ @classmethod
+ def list_processors(cls):
+ return sorted(cls.mapping["processor_name_mapping"].keys())
+
+ @classmethod
+ def list_lr_schedulers(cls):
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
+
+ @classmethod
+ def list_datasets(cls):
+ return sorted(cls.mapping["builder_name_mapping"].keys())
+
+ @classmethod
+ def get_path(cls, name):
+ return cls.mapping["paths"].get(name, None)
+
+ @classmethod
+ def get(cls, name, default=None, no_warning=False):
+ r"""Get an item from registry with key 'name'
+
+ Args:
+ name (string): Key whose value needs to be retrieved.
+ default: If passed and key is not in registry, default value will
+ be returned with a warning. Default: None
+ no_warning (bool): If passed as True, warning when key doesn't exist
+ will not be generated. Useful for MMF's
+ internal operations. Default: False
+ """
+ original_name = name
+ name = name.split(".")
+ value = cls.mapping["state"]
+ for subname in name:
+ value = value.get(subname, default)
+ if value is default:
+ break
+
+ if (
+ "writer" in cls.mapping["state"]
+ and value == default
+ and no_warning is False
+ ):
+ cls.mapping["state"]["writer"].warning(
+ "Key {} is not present in registry, returning default value "
+ "of {}".format(original_name, default)
+ )
+ return value
+
+ @classmethod
+ def unregister(cls, name):
+ r"""Remove an item from registry with key 'name'
+
+ Args:
+ name: Key which needs to be removed.
+ Usage::
+
+            from unimernet.common.registry import registry
+
+ config = registry.unregister("config")
+ """
+ return cls.mapping["state"].pop(name, None)
+
+
+registry = Registry()
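
A sketch of the registry round trip: registering a toy LR scheduler under a made-up name and looking it up again, mirroring how the decorators above are used throughout the package.

```python
from unimernet.common.registry import registry


@registry.register_lr_scheduler("constant_lr_demo")  # hypothetical name, for illustration only
class ConstantLRDemoScheduler:
    def __init__(self, optimizer, **kwargs):
        self.optimizer = optimizer

    def step(self, cur_epoch, cur_step):
        pass  # leave the optimizer's learning rate unchanged


assert registry.get_lr_scheduler_class("constant_lr_demo") is ConstantLRDemoScheduler
print(registry.list_lr_schedulers())  # includes 'constant_lr_demo' plus the built-in schedulers
```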
diff --git a/unimernet/common/utils.py b/unimernet/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6c3366b18db21db5330e85ab4e239a404312b86
--- /dev/null
+++ b/unimernet/common/utils.py
@@ -0,0 +1,424 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import io
+import json
+import logging
+import os
+import pickle
+import re
+import shutil
+import urllib
+import urllib.error
+import urllib.request
+from typing import Optional
+from urllib.parse import urlparse
+
+import numpy as np
+import pandas as pd
+import yaml
+from iopath.common.download import download
+from iopath.common.file_io import file_lock, g_pathmgr
+from unimernet.common.registry import registry
+from torch.utils.model_zoo import tqdm
+from torchvision.datasets.utils import (
+ check_integrity,
+ download_file_from_google_drive,
+ extract_archive,
+)
+
+
+def now():
+ from datetime import datetime
+
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
+
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+
+def get_cache_path(rel_path):
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
+
+
+def get_abs_path(rel_path):
+ return os.path.join(registry.get_path("library_root"), rel_path)
+
+
+def load_json(filename):
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+# The following are adapted from torchvision and vissl
+# torchvision: https://github.com/pytorch/vision
+# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
+
+
+def makedir(dir_path):
+ """
+ Create the directory if it does not exist.
+ """
+ is_success = False
+ try:
+ if not g_pathmgr.exists(dir_path):
+ g_pathmgr.mkdirs(dir_path)
+ is_success = True
+ except BaseException:
+ print(f"Error creating directory: {dir_path}")
+ return is_success
+
+
+def get_redirected_url(url: str):
+ """
+ Given a URL, returns the URL it redirects to or the
+    original URL in case of no redirection
+ """
+ import requests
+
+ with requests.Session() as session:
+ with session.get(url, stream=True, allow_redirects=True) as response:
+ if response.history:
+ return response.url
+ else:
+ return url
+
+
+def to_google_drive_download_url(view_url: str) -> str:
+ """
+ Utility function to transform a view URL of google drive
+ to a download URL for google drive
+ Example input:
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
+ Example output:
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
+ """
+ splits = view_url.split("/")
+ assert splits[-1] == "view"
+ file_id = splits[-2]
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
+
+
+def download_google_drive_url(url: str, output_path: str, output_file_name: str):
+ """
+ Download a file from google drive
+    Downloading a URL from Google Drive requires confirmation when
+    the file size is too big (Google Drive notifies that
+    anti-virus checks cannot be performed on such files)
+ """
+ import requests
+
+ with requests.Session() as session:
+
+ # First get the confirmation token and append it to the URL
+ with session.get(url, stream=True, allow_redirects=True) as response:
+ for k, v in response.cookies.items():
+ if k.startswith("download_warning"):
+ url = url + "&confirm=" + v
+
+ # Then download the content of the file
+ with session.get(url, stream=True, verify=True) as response:
+ makedir(output_path)
+ path = os.path.join(output_path, output_file_name)
+ total_size = int(response.headers.get("Content-length", 0))
+ with open(path, "wb") as file:
+ from tqdm import tqdm
+
+ with tqdm(total=total_size) as progress_bar:
+ for block in response.iter_content(
+ chunk_size=io.DEFAULT_BUFFER_SIZE
+ ):
+ file.write(block)
+ progress_bar.update(len(block))
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+ parts = urlparse(url)
+
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+ return None
+
+    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
+ if match is None:
+ return None
+
+ return match.group("id")
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
+ with open(filename, "wb") as fh:
+ with urllib.request.urlopen(
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
+ ) as response:
+ with tqdm(total=response.length) as pbar:
+ for chunk in iter(lambda: response.read(chunk_size), ""):
+ if not chunk:
+ break
+ pbar.update(chunk_size)
+ fh.write(chunk)
+
+
+def download_url(
+ url: str,
+ root: str,
+ filename: Optional[str] = None,
+ md5: Optional[str] = None,
+) -> None:
+ """Download a file from a url and place it in root.
+ Args:
+ url (str): URL to download file from
+ root (str): Directory to place downloaded file in
+ filename (str, optional): Name to save the file under.
+ If None, use the basename of the URL.
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
+ """
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = os.path.basename(url)
+ fpath = os.path.join(root, filename)
+
+ makedir(root)
+
+ # check if file is already present locally
+ if check_integrity(fpath, md5):
+ print("Using downloaded and verified file: " + fpath)
+ return
+
+ # expand redirect chain if needed
+ url = get_redirected_url(url)
+
+ # check if file is located on Google Drive
+ file_id = _get_google_drive_file_id(url)
+ if file_id is not None:
+ return download_file_from_google_drive(file_id, root, filename, md5)
+
+ # download the file
+ try:
+ print("Downloading " + url + " to " + fpath)
+ _urlretrieve(url, fpath)
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
+ if url[:5] == "https":
+ url = url.replace("https:", "http:")
+ print(
+ "Failed download. Trying https -> http instead."
+ " Downloading " + url + " to " + fpath
+ )
+ _urlretrieve(url, fpath)
+ else:
+ raise e
+
+ # check integrity of downloaded file
+ if not check_integrity(fpath, md5):
+ raise RuntimeError("File not found or corrupted.")
+
+
+def download_and_extract_archive(
+ url: str,
+ download_root: str,
+ extract_root: Optional[str] = None,
+ filename: Optional[str] = None,
+ md5: Optional[str] = None,
+ remove_finished: bool = False,
+) -> None:
+ download_root = os.path.expanduser(download_root)
+ if extract_root is None:
+ extract_root = download_root
+ if not filename:
+ filename = os.path.basename(url)
+
+ download_url(url, download_root, filename, md5)
+
+ archive = os.path.join(download_root, filename)
+ print("Extracting {} to {}".format(archive, extract_root))
+ extract_archive(archive, extract_root, remove_finished)
+
+
+def cache_url(url: str, cache_dir: str) -> str:
+ """
+ This implementation downloads the remote resource and caches it locally.
+ The resource will only be downloaded if not previously requested.
+ """
+ parsed_url = urlparse(url)
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
+ makedir(dirname)
+ filename = url.split("/")[-1]
+ cached = os.path.join(dirname, filename)
+ with file_lock(cached):
+ if not os.path.isfile(cached):
+ logging.info(f"Downloading {url} to {cached} ...")
+ cached = download(url, dirname, filename=filename)
+ logging.info(f"URL {url} cached in {cached}")
+ return cached
+
+
+# TODO (prigoyal): convert this into RAII-style API
+def create_file_symlink(file1, file2):
+ """
+    Create a symlink from file1 to file2. Useful during model
+    checkpointing to point a stable name at the latest successful
+    checkpoint.
+ """
+ try:
+ if g_pathmgr.exists(file2):
+ g_pathmgr.rm(file2)
+ g_pathmgr.symlink(file1, file2)
+ except Exception as e:
+ logging.info(f"Could NOT create symlink. Error: {e}")
+
+
+def save_file(data, filename, append_to_json=True, verbose=True):
+ """
+ Common i/o utility to handle saving data to various file formats.
+    Supported:
+        .pkl, .pickle, .npy, .json, .yaml
+    For .json, users can either append (the default) or overwrite the file
+    by setting append_to_json accordingly.
+ """
+ if verbose:
+ logging.info(f"Saving data to file: {filename}")
+ file_ext = os.path.splitext(filename)[1]
+ if file_ext in [".pkl", ".pickle"]:
+ with g_pathmgr.open(filename, "wb") as fopen:
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
+ elif file_ext == ".npy":
+ with g_pathmgr.open(filename, "wb") as fopen:
+ np.save(fopen, data)
+ elif file_ext == ".json":
+ if append_to_json:
+ with g_pathmgr.open(filename, "a") as fopen:
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
+ fopen.flush()
+ else:
+ with g_pathmgr.open(filename, "w") as fopen:
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
+ fopen.flush()
+ elif file_ext == ".yaml":
+ with g_pathmgr.open(filename, "w") as fopen:
+ dump = yaml.dump(data)
+ fopen.write(dump)
+ fopen.flush()
+ else:
+ raise Exception(f"Saving {file_ext} is not supported yet")
+
+ if verbose:
+ logging.info(f"Saved data to file: {filename}")
+
+
+def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
+ """
+ Common i/o utility to handle loading data from various file formats.
+    Supported:
+        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
+    For .npy files, reading in mmap_mode is supported; if mmap loading
+    fails, the data is loaded without mmap_mode.
+ """
+ if verbose:
+ logging.info(f"Loading data from file: {filename}")
+
+ file_ext = os.path.splitext(filename)[1]
+ if file_ext == ".txt":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = fopen.readlines()
+ elif file_ext in [".pkl", ".pickle"]:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = pickle.load(fopen, encoding="latin1")
+ elif file_ext == ".npy":
+ if mmap_mode:
+ try:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(
+ fopen,
+ allow_pickle=allow_pickle,
+ encoding="latin1",
+ mmap_mode=mmap_mode,
+ )
+ except ValueError as e:
+ logging.info(
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
+ )
+ data = np.load(
+ filename,
+ allow_pickle=allow_pickle,
+ encoding="latin1",
+ mmap_mode=mmap_mode,
+ )
+ logging.info("Successfully loaded without g_pathmgr")
+ except Exception:
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+ else:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+ elif file_ext == ".json":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = json.load(fopen)
+ elif file_ext == ".yaml":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
+ elif file_ext == ".csv":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = pd.read_csv(fopen)
+ else:
+ raise Exception(f"Reading from {file_ext} is not supported yet")
+ return data
+
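+
+# Illustrative usage sketch (not called in this module): round-tripping data through
+# save_file/load_file. The ./artifacts paths are hypothetical placeholders; g_pathmgr
+# handles plain local paths transparently.
+def _example_save_and_load_roundtrip():
+    makedir("./artifacts")
+    # .json appends one record per line by default (append_to_json=True).
+    save_file({"epoch": 1, "bleu": 0.87}, "./artifacts/metrics.json")
+    # .npy goes through np.save / np.load and can be reloaded directly.
+    save_file(np.array([1, 2, 3]), "./artifacts/ids.npy")
+    return load_file("./artifacts/ids.npy")
+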
+
+def abspath(resource_path: str):
+ """
+ Make a path absolute, but take into account prefixes like
+ "http://" or "manifold://"
+ """
+ regex = re.compile(r"^\w+://")
+ if regex.match(resource_path) is None:
+ return os.path.abspath(resource_path)
+ else:
+ return resource_path
+
+
+def makedir(dir_path):
+ """
+ Create the directory if it does not exist.
+ """
+ is_success = False
+ try:
+ if not g_pathmgr.exists(dir_path):
+ g_pathmgr.mkdirs(dir_path)
+ is_success = True
+ except BaseException:
+ logging.info(f"Error creating directory: {dir_path}")
+ return is_success
+
+
+def is_url(input_url):
+ """
+    Check if an input string is a URL. Looks for http(s):// and ignores case.
+ """
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
+ return is_url
+
+
+def cleanup_dir(dir):
+ """
+ Utility for deleting a directory. Useful for cleaning the storage space
+ that contains various training artifacts like checkpoints, data etc.
+ """
+ if os.path.exists(dir):
+ logging.info(f"Deleting directory: {dir}")
+ shutil.rmtree(dir)
+ logging.info(f"Deleted contents of directory: {dir}")
+
+
+def get_file_size(filename):
+ """
+    Given a file, return its size in MB.
+ """
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
+ return size_in_mb
diff --git a/unimernet/configs/datasets/formula/formula_eval.yaml b/unimernet/configs/datasets/formula/formula_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6e7e1ed2acd650c5a2224c77cc64bf67d5d62e5
--- /dev/null
+++ b/unimernet/configs/datasets/formula/formula_eval.yaml
@@ -0,0 +1,6 @@
+datasets:
+ formula_rec_eval:
+ data_type: images
+ build_info:
+ images: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/val
+ annotation: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/pdfmath.txt
\ No newline at end of file
diff --git a/unimernet/configs/datasets/formula/formula_train.yaml b/unimernet/configs/datasets/formula/formula_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aa4af4cf3464521f2ac30686e087a01885c03741
--- /dev/null
+++ b/unimernet/configs/datasets/formula/formula_train.yaml
@@ -0,0 +1,6 @@
+datasets:
+ formula_rec_train:
+ data_type: images
+ build_info:
+ images: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/train
+ annotation: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/pdfmath.txt
\ No newline at end of file
diff --git a/unimernet/configs/datasets/formula/multi_scale_formula_train.yaml b/unimernet/configs/datasets/formula/multi_scale_formula_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2c6dc5058ba06593d16f67882e2600acdcd9116b
--- /dev/null
+++ b/unimernet/configs/datasets/formula/multi_scale_formula_train.yaml
@@ -0,0 +1,21 @@
+datasets:
+ multi_scale_formula_rec_train:
+ data_type: images
+ build_info:
+ images: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/train
+ annotation: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/pdfmath.txt
+
+ vis_processor:
+ train:
+ name: "formula_image_multi_scale_train"
+ all_scales:
+ - [ 96, 336 ]
+ - [ 128, 448 ]
+ - [ 192, 672 ]
+ - [ 288, 1008 ]
+ - [ 384, 1344 ]
+
+ text_processor:
+ train:
+ name: "blip_caption"
+ max_words: 256
\ No newline at end of file
diff --git a/unimernet/configs/default.yaml b/unimernet/configs/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c4800a0e0d4a0444db40558ba950e08a33d80d31
--- /dev/null
+++ b/unimernet/configs/default.yaml
@@ -0,0 +1,10 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+env:
+ # For default users
+ # cache_root: "cache"
+ # For internal use with persistent storage
+ cache_root: "/export/home/.cache/vigc"
diff --git a/unimernet/configs/models/unimernet_base.yaml b/unimernet/configs/models/unimernet_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..24e31350f30b8e0b8d4fc2331cc1f0319c33f349
--- /dev/null
+++ b/unimernet/configs/models/unimernet_base.yaml
@@ -0,0 +1,31 @@
+model:
+ arch: unimernet
+ load_finetuned: False
+ load_pretrained: False
+ pretrained: "path/to/pretrained/weight"
+ finetuned: ""
+ tokenizer_name: nougat
+ tokenizer_config:
+ path: ./models/unimernet
+ model_name: unimernet
+ model_config:
+ max_seq_len: 384
+
+
+preprocess:
+ vis_processor:
+ train:
+ name: "formula_image_train"
+ image_size:
+ - 192
+ - 672
+ eval:
+ name: "formula_image_eval"
+ image_size:
+ - 192
+ - 672
+ text_processor:
+ train:
+ name: "blip_caption"
+ eval:
+ name: "blip_caption"
diff --git a/unimernet/datasets/__init__.py b/unimernet/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/datasets/__pycache__/__init__.cpython-310.pyc b/unimernet/datasets/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f42879cec833c2ad1436195309ddbe4cd93beeaf
Binary files /dev/null and b/unimernet/datasets/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/datasets/__pycache__/data_utils.cpython-310.pyc b/unimernet/datasets/__pycache__/data_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da626f1b4f156471999aca4475572beddd9b19ac
Binary files /dev/null and b/unimernet/datasets/__pycache__/data_utils.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/__init__.py b/unimernet/datasets/builders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9ad64e5aa54e4396ee752b2f1a1aa980f254f8
--- /dev/null
+++ b/unimernet/datasets/builders/__init__.py
@@ -0,0 +1,69 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.datasets.builders.base_dataset_builder import load_dataset_config
+from unimernet.common.registry import registry
+from unimernet.datasets.builders.formula import FormulaRecTrainBuilder, FormulaRecEvalBuilder, \
+ MultiScaleFormulaRecTrainBuilder
+
+__all__ = [
+ "FormulaRecTrainBuilder",
+ "FormulaRecEvalBuilder",
+ "MultiScaleFormulaRecTrainBuilder",
+]
+
+
+def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
+ """
+ Example
+
+    >>> dataset = load_dataset("formula_rec_eval", cfg_path=None)
+ >>> splits = dataset.keys()
+ >>> print([len(dataset[split]) for split in splits])
+
+ """
+ if cfg_path is None:
+ cfg = None
+ else:
+ cfg = load_dataset_config(cfg_path)
+
+ try:
+ builder = registry.get_builder_class(name)(cfg)
+ except TypeError:
+ print(
+ f"Dataset {name} not found. Available datasets:\n"
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
+ )
+ exit(1)
+
+ if vis_path is not None:
+ if data_type is None:
+ # use default data type in the config
+ data_type = builder.config.data_type
+
+ assert (
+ data_type in builder.config.build_info
+ ), f"Invalid data_type {data_type} for {name}."
+
+ builder.config.build_info.get(data_type).storage = vis_path
+
+ dataset = builder.build_datasets()
+ return dataset
+
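+
+# Illustrative sketch (not executed at import time): listing the registered builders
+# and building one of them by name. The default configs in this repo point at
+# internal cluster paths, so this only succeeds if those paths (or an overriding
+# cfg_path) are actually available.
+def _example_load_formula_eval():
+    print("available datasets:", dataset_zoo.get_names())
+    dataset = load_dataset("formula_rec_eval")
+    return {split: len(dataset[split]) for split in dataset.keys()}
+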
+
+class DatasetZoo:
+ def __init__(self) -> None:
+ self.dataset_zoo = {
+ k: list(v.DATASET_CONFIG_DICT.keys())
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
+ }
+
+ def get_names(self):
+ return list(self.dataset_zoo.keys())
+
+
+dataset_zoo = DatasetZoo()
diff --git a/unimernet/datasets/builders/__pycache__/__init__.cpython-310.pyc b/unimernet/datasets/builders/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d38a0ca49e1d07a0295d7a6eeb9a4f5b43faa29
Binary files /dev/null and b/unimernet/datasets/builders/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc b/unimernet/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d02302e0d85d678b9f40452cf49f01dad6bbf614
Binary files /dev/null and b/unimernet/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/__pycache__/formula.cpython-310.pyc b/unimernet/datasets/builders/__pycache__/formula.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef951c05619e26e0164a7d7d2a58b5b0dbdfd3c3
Binary files /dev/null and b/unimernet/datasets/builders/__pycache__/formula.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/base_dataset_builder.py b/unimernet/datasets/builders/base_dataset_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fdc086214ee9c9b92143f6233a4dc3c8a3ec06d
--- /dev/null
+++ b/unimernet/datasets/builders/base_dataset_builder.py
@@ -0,0 +1,233 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+import shutil
+import warnings
+
+import unimernet.common.utils as utils
+import torch.distributed as dist
+from unimernet.common.dist_utils import is_dist_avail_and_initialized, is_main_process
+from unimernet.common.registry import registry
+from unimernet.processors.base_processor import BaseProcessor
+from omegaconf import OmegaConf
+from torchvision.datasets.utils import download_url
+
+
+class BaseDatasetBuilder:
+ train_dataset_cls, eval_dataset_cls = None, None
+
+ def __init__(self, cfg=None):
+ super().__init__()
+
+ if cfg is None:
+ # help to create datasets from default config.
+ self.config = load_dataset_config(self.default_config_path())
+ elif isinstance(cfg, str):
+ self.config = load_dataset_config(cfg)
+ else:
+ # when called from task.build_dataset()
+ self.config = cfg
+
+ self.data_type = self.config.data_type
+
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+
+ def build_datasets(self):
+ # download, split, etc...
+ # only called on 1 GPU/TPU in distributed
+
+ if is_main_process():
+ self._download_data()
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ datasets = self.build() # dataset['train'/'val'/'test']
+
+ return datasets
+
+ def build_processors(self):
+ vis_proc_cfg = self.config.get("vis_processor")
+ txt_proc_cfg = self.config.get("text_processor")
+
+ if vis_proc_cfg is not None:
+ vis_train_cfg = vis_proc_cfg.get("train")
+ vis_eval_cfg = vis_proc_cfg.get("eval")
+
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
+
+ if txt_proc_cfg is not None:
+ txt_train_cfg = txt_proc_cfg.get("train")
+ txt_eval_cfg = txt_proc_cfg.get("eval")
+
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
+
+ @staticmethod
+ def _build_proc_from_cfg(cfg):
+ return (
+ registry.get_processor_class(cfg.name).from_config(cfg)
+ if cfg is not None
+ else None
+ )
+
+ @classmethod
+ def default_config_path(cls, type="default"):
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
+
+ def _download_data(self):
+ self._download_ann()
+ self._download_vis()
+
+ def _download_ann(self):
+ """
+ Download annotation files if necessary.
+ All the vision-language datasets should have annotations of unified format.
+
+ storage_path can be:
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
+
+ Local annotation paths should be relative.
+ """
+ anns = self.config.build_info.annotations
+
+ splits = anns.keys()
+
+ cache_root = registry.get_path("cache_root")
+
+ for split in splits:
+ info = anns[split]
+
+ urls, storage_paths = info.get("url", None), info.storage
+
+ if isinstance(urls, str):
+ urls = [urls]
+ if isinstance(storage_paths, str):
+ storage_paths = [storage_paths]
+
+ assert len(urls) == len(storage_paths)
+
+ for url_or_filename, storage_path in zip(urls, storage_paths):
+ # if storage_path is relative, make it full by prefixing with cache_root.
+ if not os.path.isabs(storage_path):
+ storage_path = os.path.join(cache_root, storage_path)
+
+ dirname = os.path.dirname(storage_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ if os.path.isfile(url_or_filename):
+ src, dst = url_or_filename, storage_path
+ if not os.path.exists(dst):
+ shutil.copyfile(src=src, dst=dst)
+ else:
+ logging.info("Using existing file {}.".format(dst))
+ else:
+ if os.path.isdir(storage_path):
+ # if only dirname is provided, suffix with basename of URL.
+ raise ValueError(
+ "Expecting storage_path to be a file path, got directory {}".format(
+ storage_path
+ )
+ )
+ else:
+ filename = os.path.basename(storage_path)
+
+ download_url(url=url_or_filename, root=dirname, filename=filename)
+
+ def _download_vis(self):
+
+ storage_path = self.config.build_info.get(self.data_type).storage
+ storage_path = utils.get_cache_path(storage_path)
+
+ if not os.path.exists(storage_path):
+ warnings.warn(
+ f"""
+ The specified path {storage_path} for visual inputs does not exist.
+ Please provide a correct path to the visual inputs or
+ refer to datasets/download_scripts/README.md for downloading instructions.
+ """
+ )
+
+ def build(self):
+ """
+        Create datasets by split, each inheriting from torch.utils.data.Dataset.
+
+ # build() can be dataset-specific. Overwrite to customize.
+ """
+ self.build_processors()
+
+ build_info = self.config.build_info
+
+ ann_info = build_info.annotations
+ vis_info = build_info.get(self.data_type)
+
+ datasets = dict()
+ for split in ann_info.keys():
+ if split not in ["train", "val", "test"]:
+ continue
+
+ is_train = split == "train"
+
+ # processors
+ vis_processor = (
+ self.vis_processors["train"]
+ if is_train
+ else self.vis_processors["eval"]
+ )
+ text_processor = (
+ self.text_processors["train"]
+ if is_train
+ else self.text_processors["eval"]
+ )
+
+ # annotation path
+ ann_paths = ann_info.get(split).storage
+ if isinstance(ann_paths, str):
+ ann_paths = [ann_paths]
+
+ abs_ann_paths = []
+ for ann_path in ann_paths:
+ if not os.path.isabs(ann_path):
+ ann_path = utils.get_cache_path(ann_path)
+ abs_ann_paths.append(ann_path)
+ ann_paths = abs_ann_paths
+
+ # visual data storage path
+ vis_path = vis_info.storage
+
+ if not os.path.isabs(vis_path):
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
+ vis_path = utils.get_cache_path(vis_path)
+
+ if not os.path.exists(vis_path):
+ warnings.warn("storage path {} does not exist.".format(vis_path))
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=vis_processor,
+ text_processor=text_processor,
+ ann_paths=ann_paths,
+ vis_root=vis_path,
+ )
+
+ return datasets
+
+
+def load_dataset_config(cfg_path):
+ cfg = OmegaConf.load(cfg_path).datasets
+ cfg = cfg[list(cfg.keys())[0]]
+
+ return cfg
diff --git a/unimernet/datasets/builders/formula.py b/unimernet/datasets/builders/formula.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e3af7568eaee4ef3ec856a7899ae6243cb5499
--- /dev/null
+++ b/unimernet/datasets/builders/formula.py
@@ -0,0 +1,105 @@
+import logging
+from unimernet.common.registry import registry
+from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+from unimernet.datasets.datasets.formula import Im2LatexDataset
+from unimernet.datasets.datasets.formula_multi_scale import MultiScaleIm2LatexDataset
+
+
+@registry.register_builder("formula_rec_train")
+class FormulaRecTrainBuilder(BaseDatasetBuilder):
+ train_dataset_cls = Im2LatexDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/formula/formula_train.yaml"
+ }
+    LOG_INFO = "Formula Recognition Train"
+
+ def build_datasets(self):
+ logging.info(f"Building {self.LOG_INFO} datasets ...")
+ self.build_processors()
+
+ build_info = self.config.build_info
+        anno_path = build_info.annotation
+ vis_root = build_info.images
+ anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
+ vis_root = [vis_root] if isinstance(vis_root, str) else vis_root
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ vis_root=vis_root,
+ anno_path=anno_path,
+ )
+ print(datasets['train'][0])
+
+ return datasets
+
+
+@registry.register_builder("multi_scale_formula_rec_train")
+class MultiScaleFormulaRecTrainBuilder(BaseDatasetBuilder):
+ train_dataset_cls = MultiScaleIm2LatexDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/formula/multi_scale_formula_train.yaml"
+ }
+    LOG_INFO = "Multi Scale Formula Recognition Train"
+
+ def build_datasets(self):
+ logging.info(f"Building {self.LOG_INFO} datasets ...")
+ self.build_processors()
+
+ build_info = self.config.build_info
+        anno_path = build_info.annotation
+ vis_root = build_info.images
+
+ anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
+ vis_root = [vis_root] if isinstance(vis_root, str) else vis_root
+
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ vis_root=vis_root,
+ anno_path=anno_path,
+ )
+ print(datasets['train'][0])
+
+ return datasets
+
+
+@registry.register_builder("formula_rec_eval")
+class FormulaRecEvalBuilder(BaseDatasetBuilder):
+ eval_dataset_cls = Im2LatexDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/formula/formula_eval.yaml"
+ }
+    LOG_INFO = "Formula Recognition Eval"
+
+ def build_datasets(self):
+ logging.info(f"Building {self.LOG_INFO} datasets ...")
+ self.build_processors()
+
+ build_info = self.config.build_info
+        anno_path = build_info.annotation
+ vis_root = build_info.images
+
+ anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
+ vis_root = [vis_root] if isinstance(vis_root, str) else vis_root
+
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.eval_dataset_cls
+ datasets['eval'] = dataset_cls(
+ vis_processor=self.vis_processors["eval"],
+ text_processor=self.text_processors["eval"],
+ vis_root=vis_root,
+ anno_path=anno_path,
+ )
+ print(datasets['eval'][0])
+
+ return datasets
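+
+
+# Illustrative sketch (not executed on import): building the eval split directly from
+# an ad-hoc config instead of the YAML defaults. The image and annotation paths below
+# are hypothetical placeholders; processor names and sizes mirror this repo's configs.
+def _example_build_eval_split():
+    from omegaconf import OmegaConf
+
+    cfg = OmegaConf.create({
+        "data_type": "images",
+        "build_info": {
+            "images": "/path/to/formula/images",          # hypothetical path
+            "annotation": "/path/to/formula/labels.txt",  # hypothetical path
+        },
+        "vis_processor": {"eval": {"name": "formula_image_eval",
+                                   "image_size": [192, 672]}},
+        "text_processor": {"eval": {"name": "blip_caption"}},
+    })
+    return FormulaRecEvalBuilder(cfg).build_datasets()["eval"]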
diff --git a/unimernet/datasets/data_utils.py b/unimernet/datasets/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e82c90ebf4fa8a094b0434c114c566e2c8b7d61
--- /dev/null
+++ b/unimernet/datasets/data_utils.py
@@ -0,0 +1,284 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import gzip
+import logging
+import os
+import random as rnd
+import tarfile
+import zipfile
+
+import decord
+import webdataset as wds
+import numpy as np
+import torch
+from torch.utils.data.dataset import IterableDataset, ChainDataset
+from decord import VideoReader
+from unimernet.common.registry import registry
+from unimernet.datasets.datasets.base_dataset import ConcatDataset
+from tqdm import tqdm
+
+decord.bridge.set_bridge("torch")
+MAX_INT = registry.get("MAX_INT")
+
+
+def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform"):
+ vr = VideoReader(uri=video_path, height=height, width=width)
+
+ vlen = len(vr)
+ start, end = 0, vlen
+
+ n_frms = min(n_frms, vlen)
+
+ if sampling == "uniform":
+ indices = np.arange(start, end, vlen / n_frms).astype(int)
+ elif sampling == "headtail":
+ indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))
+ indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2))
+ indices = indices_h + indices_t
+ else:
+ raise NotImplementedError
+
+ # get_batch -> T, H, W, C
+ frms = vr.get_batch(indices).permute(3, 0, 1, 2).float() # (C, T, H, W)
+
+ return frms
+
+
+def apply_to_sample(f, sample):
+ if len(sample) == 0:
+ return {}
+
+ def _apply(x):
+ if torch.is_tensor(x):
+ return f(x)
+ elif isinstance(x, dict):
+ return {key: _apply(value) for key, value in x.items()}
+ elif isinstance(x, list):
+ return [_apply(x) for x in x]
+ else:
+ return x
+
+ return _apply(sample)
+
+
+def move_to_cuda(sample):
+ def _move_to_cuda(tensor):
+ return tensor.cuda()
+
+ return apply_to_sample(_move_to_cuda, sample)
+
+
+def prepare_sample(samples, cuda_enabled=True):
+ if cuda_enabled:
+ samples = move_to_cuda(samples)
+
+ # TODO fp16 support
+
+ return samples
+
+
+def reorg_datasets_by_split(datasets):
+ """
+ Organizes datasets by split.
+
+ Args:
+ datasets: dict of torch.utils.data.Dataset objects by name.
+
+ Returns:
+ Dict of datasets by split {split_name: List[Datasets]}.
+ """
+ # if len(datasets) == 1:
+ # return datasets[list(datasets.keys())[0]]
+ # else:
+ reorg_datasets = dict()
+
+ # reorganize by split
+ for _, dataset in datasets.items():
+ for split_name, dataset_split in dataset.items():
+ if split_name not in reorg_datasets:
+ reorg_datasets[split_name] = [dataset_split]
+ else:
+ reorg_datasets[split_name].append(dataset_split)
+
+ return reorg_datasets
+
+
+def concat_datasets(datasets):
+ """
+ Concatenates multiple datasets into a single dataset.
+
+    It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not support
+ generic IterableDataset because it requires creating separate samplers.
+
+    Currently only supports concatenating training datasets, assuming validation and testing
+ have only a single dataset. This is because metrics should not be computed on the concatenated
+ datasets.
+
+ Args:
+ datasets: dict of torch.utils.data.Dataset objects by split.
+
+ Returns:
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
+ "val" and "test" remain the same.
+
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
+ a tuple, where the first element is a concatenated map-style dataset and the second
+ element is a chained DataPipeline dataset.
+
+ """
+ # concatenate datasets in the same split
+ for split_name in datasets:
+ if split_name != "train":
+ assert (
+ len(datasets[split_name]) == 1
+ ), "Do not support multiple {} datasets.".format(split_name)
+ datasets[split_name] = datasets[split_name][0]
+ else:
+ iterable_datasets, map_datasets = [], []
+ for dataset in datasets[split_name]:
+ if isinstance(dataset, wds.DataPipeline):
+ logging.info(
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
+ dataset
+ )
+ )
+ iterable_datasets.append(dataset)
+ elif isinstance(dataset, IterableDataset):
+ raise NotImplementedError(
+ "Do not support concatenation of generic IterableDataset."
+ )
+ else:
+ map_datasets.append(dataset)
+
+ # if len(iterable_datasets) > 0:
+ # concatenate map-style datasets and iterable-style datasets separately
+ chained_datasets = (
+ ChainDataset(iterable_datasets) if len(iterable_datasets) > 0 else None
+ )
+ concat_datasets = (
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+ )
+
+ train_datasets = concat_datasets, chained_datasets
+ train_datasets = tuple([x for x in train_datasets if x is not None])
+ train_datasets = (
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
+ )
+
+ datasets[split_name] = train_datasets
+
+ return datasets
+
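+
+# Illustrative sketch (not called in this module): reorganizing two toy map-style
+# datasets by split and concatenating their train splits. TensorDataset stands in
+# for the project's real datasets here.
+def _example_concat_two_train_sets():
+    from torch.utils.data import TensorDataset
+
+    a = TensorDataset(torch.zeros(4, 3))
+    b = TensorDataset(torch.ones(6, 3))
+    by_split = reorg_datasets_by_split({"ds_a": {"train": a}, "ds_b": {"train": b}})
+    merged = concat_datasets(by_split)
+    return len(merged["train"])  # 10: a ConcatDataset over both train sets
+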
+
+def extract_archive(from_path, to_path=None, overwrite=False):
+ """Extract archive.
+
+ Args:
+ from_path: the path of the archive.
+ to_path: the root path of the extracted files (directory of from_path)
+ overwrite: overwrite existing files (False)
+
+ Returns:
+ List of paths to extracted files even if not overwritten.
+
+ Examples:
+ >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
+ >>> from_path = './validation.tar.gz'
+ >>> to_path = './'
+ >>> torchtext.utils.download_from_url(url, from_path)
+ >>> torchtext.utils.extract_archive(from_path, to_path)
+ >>> ['.data/val.de', '.data/val.en']
+ >>> torchtext.utils.download_from_url(url, from_path)
+ >>> torchtext.utils.extract_archive(from_path, to_path)
+ >>> ['.data/val.de', '.data/val.en']
+
+ """
+
+ if to_path is None:
+ to_path = os.path.dirname(from_path)
+
+ if from_path.endswith((".tar.gz", ".tgz")):
+ logging.info("Opening tar file {} to {}.".format(from_path, to_path))
+ with tarfile.open(from_path, "r") as tar:
+ files = []
+ for file_ in tqdm(tar):
+ file_path = os.path.join(to_path, file_.name)
+ if file_.isfile():
+ files.append(file_path)
+ if os.path.exists(file_path):
+ logging.info("{} already extracted.".format(file_path))
+ if not overwrite:
+ continue
+ tar.extract(file_, to_path)
+ logging.info("Finished extracting tar file {}.".format(from_path))
+ return files
+
+ elif from_path.endswith(".zip"):
+ assert zipfile.is_zipfile(from_path), from_path
+ logging.info("Opening zip file {} to {}.".format(from_path, to_path))
+ with zipfile.ZipFile(from_path, "r") as zfile:
+ files = []
+ for file_ in tqdm(zfile.namelist()):
+ file_path = os.path.join(to_path, file_)
+ files.append(file_path)
+ if os.path.exists(file_path):
+ logging.info("{} already extracted.".format(file_path))
+ if not overwrite:
+ continue
+ zfile.extract(file_, to_path)
+ files = [f for f in files if os.path.isfile(f)]
+ logging.info("Finished extracting zip file {}.".format(from_path))
+ return files
+
+ elif from_path.endswith(".gz"):
+ logging.info("Opening gz file {} to {}.".format(from_path, to_path))
+ default_block_size = 65536
+ filename = from_path[:-3]
+ files = [filename]
+ with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file:
+            while True:
+                block = gzfile.read(default_block_size)
+                if not block:
+                    break
+                d_file.write(block)
+ logging.info("Finished extracting gz file {}.".format(from_path))
+ return files
+
+ else:
+ raise NotImplementedError(
+ "We currently only support tar.gz, .tgz, .gz and zip achives."
+ )
+
+
+def save_frames_grid(img_array, out_path):
+ import torch
+ from PIL import Image
+ from torchvision.utils import make_grid
+
+ if len(img_array.shape) == 3:
+ img_array = img_array.unsqueeze(0)
+ elif len(img_array.shape) == 5:
+ b, t, c, h, w = img_array.shape
+ img_array = img_array.view(-1, c, h, w)
+ elif len(img_array.shape) == 4:
+ pass
+ else:
+ raise NotImplementedError(
+ "Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored."
+ )
+
+    assert img_array.shape[1] == 3, "Expecting 3-channel (RGB) input, i.e. shape (N, 3, H, W)."
+
+ grid = make_grid(img_array)
+ ndarr = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+
+ img = Image.fromarray(ndarr)
+
+ img.save(out_path)
diff --git a/unimernet/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc b/unimernet/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b66d52b14d2c07e62d7f928930209753687b9e1
Binary files /dev/null and b/unimernet/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc differ
diff --git a/unimernet/datasets/datasets/__pycache__/formula.cpython-310.pyc b/unimernet/datasets/datasets/__pycache__/formula.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c75254a5a71a8f172cb61aea9c6faa75439dcd
Binary files /dev/null and b/unimernet/datasets/datasets/__pycache__/formula.cpython-310.pyc differ
diff --git a/unimernet/datasets/datasets/__pycache__/formula_multi_scale.cpython-310.pyc b/unimernet/datasets/datasets/__pycache__/formula_multi_scale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0b9ae43b5ea5491a780bff67b6f4eb8fd0946d6
Binary files /dev/null and b/unimernet/datasets/datasets/__pycache__/formula_multi_scale.cpython-310.pyc differ
diff --git a/unimernet/datasets/datasets/base_dataset.py b/unimernet/datasets/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..591ab94db2a688dc3c74d06c72a1102333411a8f
--- /dev/null
+++ b/unimernet/datasets/datasets/base_dataset.py
@@ -0,0 +1,103 @@
+import json
+from PIL import Image, ImageFile
+import os.path as osp
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+from io import BytesIO
+from typing import Iterable
+from torch.utils.data import Dataset, ConcatDataset
+import torch
+
+
+class BaseDataset(Dataset):
+
+ def __init__(self, vis_processor, text_processor, vis_root, anno_path):
+
+ self.vis_root = vis_root
+ # if isinstance(anno_path, tuple) or isinstance(anno_path, list):
+ # anno_path = anno_path[0]
+ self.anno_path = anno_path
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.samples = self.init_samples()
+ self.reader = self.init_reader()
+
+ print('total {} {} samples'.format(self.__len__(), self.__class__.__name__))
+
+        # Eagerly fetch a few samples so data or annotation problems surface early.
+        for idx in range(10):
+            self.__getitem__(idx)
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def init_samples(self):
+ # read annotation from ceph
+ if self.anno_path.startswith('cluster'):
+ from petrel_client.client import Client
+ client = Client("~/petreloss.conf")
+ samples = json.loads(client.get(self.anno_path))
+ else:
+ samples = json.load(open(self.anno_path, 'r'))
+ return samples
+
+ def init_reader(self):
+ if self.vis_root.startswith('cluster'):
+ from petrel_client.client import Client
+ client = Client("~/petreloss.conf")
+ reader = {'type': 'PetrelReader', 'body': client.get}
+ else:
+ reader = {'type': 'LocalReader', 'body': Image.open}
+ return reader
+
+ def _read_image(self, sample, image_key="image"):
+ img_file = sample[image_key]
+ image_path = osp.join(self.vis_root, img_file)
+ image = self.reader['body'](image_path)
+ if isinstance(image, bytes):
+ bytes_stream = BytesIO(image)
+ image = Image.open(bytes_stream)
+ image = image.convert("RGB")
+ return image
+
+ def collater(self, samples):
+ image_list, question_list, answer_list = [], [], []
+
+ for sample in samples:
+ image_list.append(sample["image"])
+ question_list.append(sample["text_input"])
+ answer_list.append(sample["text_output"])
+
+ return {
+ "image": torch.stack(image_list, dim=0),
+ "text_input": question_list,
+ "text_output": answer_list,
+ "data_type": "vqa",
+ }
+
+
+class ConcatDataset(ConcatDataset):
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
+ super().__init__(datasets)
+
+ def collater(self, samples):
+ # TODO For now only supports datasets with same underlying collater implementations
+
+ all_keys = set()
+ for s in samples:
+ all_keys.update(s)
+
+ shared_keys = all_keys
+ for s in samples:
+ shared_keys = shared_keys & set(s.keys())
+
+ samples_shared_keys = []
+ for s in samples:
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+ return self.datasets[0].collater(samples_shared_keys)
diff --git a/unimernet/datasets/datasets/dataloader_utils.py b/unimernet/datasets/datasets/dataloader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab63bf9b6dce16447f816763210cf87ca3940097
--- /dev/null
+++ b/unimernet/datasets/datasets/dataloader_utils.py
@@ -0,0 +1,200 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import time
+import random
+import torch
+from unimernet.datasets.data_utils import move_to_cuda
+from torch.utils.data import DataLoader
+
+
+class MultiIterLoader:
+ """
+ A simple wrapper for iterating over multiple iterators.
+
+ Args:
+ loaders (List[Loader]): List of Iterator loaders.
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
+ """
+
+ def __init__(self, loaders, ratios=None):
+ # assert all loaders has __next__ method
+ for loader in loaders:
+ assert hasattr(
+ loader, "__next__"
+ ), "Loader {} has no __next__ method.".format(loader)
+
+ if ratios is None:
+ ratios = [1.0] * len(loaders)
+ else:
+ assert len(ratios) == len(loaders)
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
+
+ self.loaders = loaders
+ self.ratios = ratios
+
+ def __next__(self):
+ # random sample from each loader by ratio
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
+ return next(self.loaders[loader_idx])
+
+ def __len__(self):
+ return sum([len(_) for _ in self.loaders if hasattr(_, "__len__")])
+
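+
+# Illustrative sketch (not called in this module): MultiIterLoader drawing from two
+# plain iterators, sampling the first roughly twice as often as the second.
+def _example_multi_iter_loader():
+    loader = MultiIterLoader(
+        loaders=[iter(["a1", "a2", "a3", "a4"]), iter(["b1", "b2"])],
+        ratios=[2, 1],  # normalized internally to sampling probabilities
+    )
+    return [next(loader) for _ in range(2)]
+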
+
+class ConcatLoader:
+ """
+    A wrapper that iterates over multiple loaders until all of them are exhausted,
+    sampling each loader in proportion to its remaining length.
+
+    Args:
+        loaders (List[Loader]): List of Iterator loaders; each must also implement __len__.
+ """
+
+ def __init__(self, loaders):
+ # assert all loaders has __next__ method
+ for loader in loaders:
+ assert hasattr(
+ loader, "__len__"
+ ), "Loader {} has no __len__ method.".format(loader)
+
+ self._epoch = 0
+ self._loader_lens = [len(_) for _ in loaders]
+ self._rest_lens = self._loader_lens.copy()
+
+ self.loaders = loaders
+
+ def __next__(self):
+ # random sample from each loader by ratio
+ loader_idx = random.choices(range(len(self.loaders)), self._rest_lens, k=1)[0]
+ self._rest_lens[loader_idx] -= 1
+ if sum(self._rest_lens) == 0:
+ self._epoch += 1
+ self._rest_lens = self._loader_lens.copy()
+ return next(self.loaders[loader_idx])
+
+ def __len__(self):
+ return sum([len(_) for _ in self.loaders if hasattr(_, "__len__")])
+
+
+class PrefetchLoader(object):
+ """
+ Modified from https://github.com/ChenRocks/UNITER.
+
+    Overlaps compute and CUDA data transfer
+    (copied and then modified from NVIDIA Apex).
+ """
+
+ def __init__(self, loader):
+ self.loader = loader
+ self.stream = torch.cuda.Stream()
+
+ def __iter__(self):
+ loader_it = iter(self.loader)
+ self.preload(loader_it)
+ batch = self.next(loader_it)
+ while batch is not None:
+ is_tuple = isinstance(batch, tuple)
+ if is_tuple:
+ task, batch = batch
+
+ if is_tuple:
+ yield task, batch
+ else:
+ yield batch
+ batch = self.next(loader_it)
+
+ def __len__(self):
+ return len(self.loader)
+
+ def preload(self, it):
+ try:
+ self.batch = next(it)
+ except StopIteration:
+ self.batch = None
+ return
+ # if record_stream() doesn't work, another option is to make sure
+ # device inputs are created on the main stream.
+ # self.next_input_gpu = torch.empty_like(self.next_input,
+ # device='cuda')
+ # self.next_target_gpu = torch.empty_like(self.next_target,
+ # device='cuda')
+ # Need to make sure the memory allocated for next_* is not still in use
+ # by the main stream at the time we start copying to next_*:
+ # self.stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.stream):
+ self.batch = move_to_cuda(self.batch)
+ # more code for the alternative if record_stream() doesn't work:
+ # copy_ will record the use of the pinned source tensor in this
+ # side stream.
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+ # self.next_input = self.next_input_gpu
+ # self.next_target = self.next_target_gpu
+
+ def next(self, it):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is not None:
+ record_cuda_stream(batch)
+ self.preload(it)
+ return batch
+
+ def __getattr__(self, name):
+ method = self.loader.__getattribute__(name)
+ return method
+
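+
+# Illustrative sketch (not called in this module): wrapping an ordinary DataLoader so
+# host-to-GPU copies overlap with compute. Requires a CUDA device; pin_memory makes
+# the asynchronous copies effective. `dataset` is any map-style dataset.
+def _example_prefetch(dataset, batch_size=32):
+    loader = DataLoader(dataset, batch_size=batch_size, pin_memory=True, num_workers=2)
+    for batch in PrefetchLoader(loader):
+        pass  # batch tensors are already on the GPU at this point
+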
+
+def record_cuda_stream(batch):
+ if isinstance(batch, torch.Tensor):
+ batch.record_stream(torch.cuda.current_stream())
+ elif isinstance(batch, list) or isinstance(batch, tuple):
+ for t in batch:
+ record_cuda_stream(t)
+ elif isinstance(batch, dict):
+ for t in batch.values():
+ record_cuda_stream(t)
+ else:
+ pass
+
+
+class IterLoader:
+ """
+    A wrapper that turns a DataLoader into an infinite iterator.
+
+ Modified from:
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
+ """
+
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
+ self._dataloader = dataloader
+ self.iter_loader = iter(self._dataloader)
+ self._use_distributed = use_distributed
+ self._epoch = 0
+
+ @property
+ def epoch(self) -> int:
+ return self._epoch
+
+ def __next__(self):
+ try:
+ data = next(self.iter_loader)
+ except StopIteration:
+ self._epoch += 1
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
+ self._dataloader.sampler.set_epoch(self._epoch)
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.iter_loader = iter(self._dataloader)
+ data = next(self.iter_loader)
+
+ return data
+
+ def __iter__(self):
+ return self
+
+ def __len__(self):
+ return len(self._dataloader)
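+
+
+# Illustrative sketch (not called in this module): IterLoader keeps yielding batches
+# past the end of the underlying DataLoader, bumping its epoch counter each time it
+# wraps around (with a short sleep at every epoch transition).
+def _example_iter_loader():
+    from torch.utils.data import TensorDataset
+
+    dataloader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=4)
+    infinite = IterLoader(dataloader)  # two batches per pass over the data
+    batches = [next(infinite) for _ in range(5)]  # silently crosses epoch boundaries
+    return infinite.epoch, len(batches)  # epoch has advanced to 2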
diff --git a/unimernet/datasets/datasets/formula.py b/unimernet/datasets/datasets/formula.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a628b149099f621502eb014cd8747545497f57a
--- /dev/null
+++ b/unimernet/datasets/datasets/formula.py
@@ -0,0 +1,71 @@
+import torch
+from .base_dataset import BaseDataset
+import os.path as osp
+import glob
+from io import BytesIO
+from PIL import Image
+
+
+class Im2LatexDataset(BaseDataset):
+
+ def init_samples(self):
+ samples = []
+ for vis_root, anno_path in zip(self.vis_root, self.anno_path):
+ images = [path.replace('\\', '/') for path in glob.glob(osp.join(vis_root, '*.png'))]
+ indices = [int(osp.basename(img).split('.')[0]) for img in images]
+
+ eqs = open(anno_path, 'r').read().split('\n')
+ eqs = [eqs[_] for _ in indices]
+
+ for i, e in zip(images, eqs):
+ samples.append({"image": i, "equation": e, "vis_root": vis_root})
+ return samples
+
+ def __getitem__(self, index):
+ ann = self.samples[index]
+ try:
+ image = self.vis_processor(self._read_image(ann))
+ except Exception:
+ return self[(index + 1) % len(self)]
+ if image is None:
+ return self[(index + 1) % len(self)]
+ equation = ann["equation"]
+ return {"image": image, "text_input": equation, "id": index}
+
+ def _read_image(self, sample, image_key="image"):
+ img_file = sample[image_key]
+ vis_root = sample["vis_root"]
+ image_path = osp.join(vis_root, img_file)
+ image = self.reader['body'](image_path)
+ if isinstance(image, bytes):
+ bytes_stream = BytesIO(image)
+ image = Image.open(bytes_stream)
+ image = image.convert("RGB")
+ return image
+
+ def init_reader(self):
+ if not isinstance(self.vis_root, str):
+ vis_root = self.vis_root[0]
+ else:
+ vis_root = self.vis_root
+ if vis_root.startswith('cluster'):
+ from petrel_client.client import Client
+ client = Client("~/petreloss.conf")
+ reader = {'type': 'PetrelReader', 'body': client.get}
+ else:
+ reader = {'type': 'LocalReader', 'body': Image.open}
+ return reader
+
+ def collater(self, samples):
+ image_list, question_list, id_list = [], [], []
+
+ for sample in samples:
+ image_list.append(sample["image"])
+ question_list.append(sample["text_input"])
+ id_list.append(sample["id"])
+
+ return {
+ "image": torch.stack(image_list, dim=0),
+ "text_input": question_list,
+ "id": id_list
+ }
diff --git a/unimernet/datasets/datasets/formula_multi_scale.py b/unimernet/datasets/datasets/formula_multi_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ab04648063ef7c3ef5b2586d80d7fe35e010e8
--- /dev/null
+++ b/unimernet/datasets/datasets/formula_multi_scale.py
@@ -0,0 +1,32 @@
+import torch
+from .formula import Im2LatexDataset
+
+
+class MultiScaleIm2LatexDataset(Im2LatexDataset):
+
+ def __getitem__(self, index):
+ ann = self.samples[index]
+ try:
+ pil_image = self._read_image(ann)
+ image = self.vis_processor(pil_image)
+ except Exception:
+ return self[(index + 1) % len(self)]
+ if image is None:
+ return self[(index + 1) % len(self)]
+ equation = ann["equation"]
+ return {"image": image, "text_input": equation, "id": index, "raw_image": pil_image}
+
+ def collater(self, samples):
+ self.vis_processor.reset_scale()
+ image_list, question_list, id_list = [], [], []
+
+ for sample in samples:
+ image_list.append(self.vis_processor(sample["raw_image"]))
+ question_list.append(sample["text_input"])
+ id_list.append(sample["id"])
+
+ return {
+ "image": torch.stack(image_list, dim=0),
+ "text_input": question_list,
+ "id": id_list
+ }
diff --git a/unimernet/models/__init__.py b/unimernet/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b86bd4f5376f3b0597b1f008ac893985bc0a06fc
--- /dev/null
+++ b/unimernet/models/__init__.py
@@ -0,0 +1,198 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import torch
+from omegaconf import OmegaConf
+from unimernet.common.registry import registry
+
+from unimernet.models.base_model import BaseModel
+
+from unimernet.processors.base_processor import BaseProcessor
+from unimernet.models.unimernet.unimernet import UniMERModel
+
+__all__ = [
+ "load_model",
+ "BaseModel",
+ "UniMERModel",
+]
+
+
+def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
+ """
+ Load supported models.
+
+ To list all available models and types in registry:
+ >>> from unimernet.models import model_zoo
+ >>> print(model_zoo)
+
+ Args:
+ name (str): name of the model.
+ model_type (str): type of the model.
+ is_eval (bool): whether the model is in eval mode. Default: False.
+ device (str): device to use. Default: "cpu".
+        checkpoint (str): path or URL to a checkpoint. Default: None.
+            The checkpoint is expected to have the same state_dict keys as the model.
+
+ Returns:
+ model (torch.nn.Module): model.
+ """
+
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
+
+ if checkpoint is not None:
+ model.load_checkpoint(checkpoint)
+
+ if is_eval:
+ model.eval()
+
+ if device == "cpu":
+ model = model.float()
+
+ return model.to(device)
+
+
+def load_preprocess(config):
+ """
+ Load preprocessor configs and construct preprocessors.
+
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
+
+ Args:
+ config (dict): preprocessor configs.
+
+ Returns:
+ vis_processors (dict): preprocessors for visual inputs.
+ txt_processors (dict): preprocessors for text inputs.
+
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
+ """
+
+ def _build_proc_from_cfg(cfg):
+ return (
+ registry.get_processor_class(cfg.name).from_config(cfg)
+ if cfg is not None
+ else BaseProcessor()
+ )
+
+ vis_processors = dict()
+ txt_processors = dict()
+
+ vis_proc_cfg = config.get("vis_processor")
+ txt_proc_cfg = config.get("text_processor")
+
+ if vis_proc_cfg is not None:
+ vis_train_cfg = vis_proc_cfg.get("train")
+ vis_eval_cfg = vis_proc_cfg.get("eval")
+ else:
+ vis_train_cfg = None
+ vis_eval_cfg = None
+
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
+
+ if txt_proc_cfg is not None:
+ txt_train_cfg = txt_proc_cfg.get("train")
+ txt_eval_cfg = txt_proc_cfg.get("eval")
+ else:
+ txt_train_cfg = None
+ txt_eval_cfg = None
+
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
+
+ return vis_processors, txt_processors
+
+
+def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
+ """
+ Load model and its related preprocessors.
+
+ List all available models and types in registry:
+ >>> from unimernet.models import model_zoo
+ >>> print(model_zoo)
+
+ Args:
+ name (str): name of the model.
+ model_type (str): type of the model.
+ is_eval (bool): whether the model is in eval mode. Default: False.
+ device (str): device to use. Default: "cpu".
+
+ Returns:
+ model (torch.nn.Module): model.
+ vis_processors (dict): preprocessors for visual inputs.
+ txt_processors (dict): preprocessors for text inputs.
+ """
+ model_cls = registry.get_model_class(name)
+
+ # load model
+ model = model_cls.from_pretrained(model_type=model_type)
+
+ if is_eval:
+ model.eval()
+
+ # load preprocess
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
+ if cfg is not None:
+ preprocess_cfg = cfg.preprocess
+
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+ else:
+ vis_processors, txt_processors = None, None
+ logging.info(
+ f"""No default preprocess for model {name} ({model_type}).
+ This can happen if the model is not finetuned on downstream datasets,
+ or it is not intended for direct use without finetuning.
+ """
+ )
+
+ if device == "cpu" or device == torch.device("cpu"):
+ model = model.float()
+
+ return model.to(device), vis_processors, txt_processors
+
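+
+# Illustrative sketch (not called in this module): loading a model together with its
+# preprocessors. The architecture name "unimernet" matches this repo's configs; the
+# model_type string is an assumption and should be checked against model_zoo.
+def _example_load_unimernet(device="cpu"):
+    print(model_zoo)  # lists registered architectures and their model types
+    model, vis_processors, txt_processors = load_model_and_preprocess(
+        name="unimernet", model_type="unimernet", is_eval=True, device=device
+    )
+    return model, vis_processors["eval"], txt_processors["eval"]
+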
+
+class ModelZoo:
+ """
+ A utility class to create string representation of available model architectures and types.
+
+ >>> from unimernet.models import model_zoo
+ >>> # list all available models
+ >>> print(model_zoo)
+ >>> # show total number of models
+ >>> print(len(model_zoo))
+ """
+
+ def __init__(self) -> None:
+ self.model_zoo = {
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
+ for k, v in registry.mapping["model_name_mapping"].items()
+ }
+
+ def __str__(self) -> str:
+ return (
+ "=" * 50
+ + "\n"
+ + f"{'Architectures':<30} {'Types'}\n"
+ + "=" * 50
+ + "\n"
+ + "\n".join(
+ [
+ f"{name:<30} {', '.join(types)}"
+ for name, types in self.model_zoo.items()
+ ]
+ )
+ )
+
+ def __iter__(self):
+ return iter(self.model_zoo.items())
+
+ def __len__(self):
+ return sum([len(v) for v in self.model_zoo.values()])
+
+
+model_zoo = ModelZoo()
diff --git a/unimernet/models/__pycache__/__init__.cpython-310.pyc b/unimernet/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b70668ef663906775cebc29ec58d545e464b378
Binary files /dev/null and b/unimernet/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/models/__pycache__/base_model.cpython-310.pyc b/unimernet/models/__pycache__/base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83f3bab0ee55232ac666e7064ab83f8228d70984
Binary files /dev/null and b/unimernet/models/__pycache__/base_model.cpython-310.pyc differ
diff --git a/unimernet/models/__pycache__/clip_vit.cpython-310.pyc b/unimernet/models/__pycache__/clip_vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c0d8c99acd91ce10388b7224ac0bf1d30f10f7a
Binary files /dev/null and b/unimernet/models/__pycache__/clip_vit.cpython-310.pyc differ
diff --git a/unimernet/models/__pycache__/eva_vit.cpython-310.pyc b/unimernet/models/__pycache__/eva_vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..021e9e282399735f4d525e5319132debb3433568
Binary files /dev/null and b/unimernet/models/__pycache__/eva_vit.cpython-310.pyc differ
diff --git a/unimernet/models/base_model.py b/unimernet/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..104fae583941d9192002cf7d2196fcf92d0f28e9
--- /dev/null
+++ b/unimernet/models/base_model.py
@@ -0,0 +1,251 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from unimernet.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
+from unimernet.common.utils import get_abs_path, is_url
+from omegaconf import OmegaConf
+
+
+class BaseModel(nn.Module):
+ """Base class for models."""
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def device(self):
+ return list(self.parameters())[0].device
+
+ def load_checkpoint(self, url_or_filename):
+ """
+ Load from a finetuned checkpoint.
+
+ This should expect no mismatch in the model keys and the checkpoint keys.
+ """
+
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ if "model" in checkpoint.keys():
+ state_dict = checkpoint["model"]
+ else:
+ state_dict = checkpoint
+
+ msg = self.load_state_dict(state_dict, strict=False)
+
+ # logging.info("Missing keys {}".format(msg.missing_keys))
+ logging.info(f"Missing keys exist when loading '{url_or_filename}'.")
+ logging.info("load checkpoint from %s" % url_or_filename)
+
+ return msg
+
+ @classmethod
+ def from_pretrained(cls, model_type):
+ """
+ Build a pretrained model from default configuration file, specified by model_type.
+
+ Args:
+ - model_type (str): model type, specifying architecture and checkpoints.
+
+ Returns:
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
+ """
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
+ model = cls.from_config(model_cfg)
+
+ return model
+
+ @classmethod
+ def default_config_path(cls, model_type):
+ assert (
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
+ ), "Unknown model type {}".format(model_type)
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
+
+ def load_checkpoint_from_config(self, cfg, **kwargs):
+ """
+ Load checkpoint as specified in the config file.
+
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
+ When loading the pretrained model, each task-specific architecture may define their
+ own load_from_pretrained() method.
+ """
+ load_pretrained = cfg.get("load_pretrained", True)
+ load_finetuned = cfg.get("load_finetuned", False)
+
+ if load_pretrained:
+ # load pre-trained weights
+ pretrain_path = cfg.get("pretrained", None)
+            assert pretrain_path, "Found load_pretrained is True, but the pretrained path is None."
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
+ logging.info(f"Loaded pretrained model '{pretrain_path}'.")
+
+ if load_finetuned:
+ finetune_path = cfg.get("finetuned", None)
+ assert finetune_path is not None, "Found load_finetuned is True, but finetune_path is None."
+ self.load_checkpoint(url_or_filename=finetune_path)
+ logging.info(f"Loaded finetuned model '{finetune_path}'.")
+
+ def before_evaluation(self, **kwargs):
+ pass
+
+ def show_n_params(self, return_str=True):
+ tot = 0
+ for p in self.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return "{:.1f}M".format(tot / 1e6)
+ else:
+ return "{:.1f}K".format(tot / 1e3)
+ else:
+ return tot
+
+
+class BaseEncoder(nn.Module):
+ """
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward_features(self, samples, **kwargs):
+ raise NotImplementedError
+
+ @property
+ def device(self):
+ return list(self.parameters())[0].device
+
+
+class SharedQueueMixin:
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
+ # gather keys before updating queue
+ image_feats = concat_all_gather(image_feat)
+ text_feats = concat_all_gather(text_feat)
+
+ batch_size = image_feats.shape[0]
+
+ ptr = int(self.queue_ptr)
+ assert self.queue_size % batch_size == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+ self.image_queue[:, ptr: ptr + batch_size] = image_feats.T
+ self.text_queue[:, ptr: ptr + batch_size] = text_feats.T
+
+ if idxs is not None:
+ idxs = concat_all_gather(idxs)
+ self.idx_queue[:, ptr: ptr + batch_size] = idxs.T
+
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
+ self.queue_ptr[0] = ptr
+
+
+class MomentumDistilationMixin:
+ @torch.no_grad()
+ def copy_params(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(
+ model_pair[0].parameters(), model_pair[1].parameters()
+ ):
+ param_m.data.copy_(param.data) # initialize
+ param_m.requires_grad = False # not update by gradient
+
+ @torch.no_grad()
+ def _momentum_update(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(
+ model_pair[0].parameters(), model_pair[1].parameters()
+ ):
+ param_m.data = param_m.data * self.momentum + param.data * (
+ 1.0 - self.momentum
+ )
+
+
+class GatherLayer(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+    unlike torch.distributed.all_gather, this implementation does not cut the gradients.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ torch.distributed.all_reduce(all_gradients)
+ return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+ """
+ Performs all_gather operation on the provided tensors.
+ Graph remains connected for backward grad computation.
+ """
+ # Queue the gathered tensors
+ world_size = torch.distributed.get_world_size()
+ # There is no need for reduction in the single-proc case
+ if world_size == 1:
+ return tensors
+
+    tensor_all = GatherLayer.apply(tensors)
+
+ return torch.cat(tensor_all, dim=0)
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+    # if distributed training is not initialized, there is nothing to gather
+ if not is_dist_avail_and_initialized():
+ return tensor
+
+ tensors_gather = [
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
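+# Illustrative usage sketch (hypothetical variable names): in a contrastive objective,
+# features that must receive gradients from other ranks go through all_gather_with_grad,
+# while detached targets can be gathered with concat_all_gather below:
+#
+#     image_feat_all = all_gather_with_grad(image_feat)  # gradients flow back to every rank
+#     text_feat_all = concat_all_gather(text_feat)       # gathered without gradients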
+
+def tile(x, dim, n_tile):
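+    # Repeat every element of `x` n_tile times along `dim`, interleaved like
+    # torch.repeat_interleave: tile(torch.tensor([a, b]), 0, 2) -> [a, a, b, b].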
+ init_dim = x.size(dim)
+ repeat_idx = [1] * x.dim()
+ repeat_idx[dim] = n_tile
+ x = x.repeat(*(repeat_idx))
+ order_index = torch.LongTensor(
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
+ )
+ return torch.index_select(x, dim, order_index.to(x.device))
diff --git a/unimernet/models/blip2_models/Qformer.py b/unimernet/models/blip2_models/Qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71b12375e10511858a9c505dc795181e6ce5603
--- /dev/null
+++ b/unimernet/models/blip2_models/Qformer.py
@@ -0,0 +1,1216 @@
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+ )
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size
+ )
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
+ )
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ query_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ seq_length = input_ids.size()[1]
+ else:
+ seq_length = 0
+
+ if position_ids is None:
+ position_ids = self.position_ids[
+ :, past_key_values_length : seq_length + past_key_values_length
+ ].clone()
+
+ if input_ids is not None:
+ embeddings = self.word_embeddings(input_ids)
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = embeddings + position_embeddings
+
+ if query_embeds is not None:
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
+ else:
+ embeddings = query_embeds
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+ config, "embedding_size"
+ ):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
+ )
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
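+        # Reshape (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)
+        # so attention scores can be computed independently for each head.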
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ mixed_query_layer = self.query(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(-1, 1)
+ position_ids_r = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(
+ distance + self.max_position_embeddings - 1
+ )
+ positional_embedding = positional_embedding.to(
+ dtype=query_layer.dtype
+ ) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ relative_position_scores_key = torch.einsum(
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
+ )
+ attention_scores = (
+ attention_scores
+ + relative_position_scores_query
+ + relative_position_scores_key
+ )
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
+ )
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.self.num_attention_heads,
+ self.self.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = (
+ self.self.attention_head_size * self.self.num_attention_heads
+ )
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[
+ 1:
+ ] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if (
+ self.config.add_cross_attention
+ and layer_num % self.config.cross_attention_freq == 0
+ ):
+ self.crossattention = BertAttention(
+ config, is_cross_attention=self.config.add_cross_attention
+ )
+ self.has_cross_attention = True
+ else:
+ self.has_cross_attention = False
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ self.intermediate_query = BertIntermediate(config)
+ self.output_query = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ query_length=0,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = (
+ past_key_value[:2] if past_key_value is not None else None
+ )
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:-1]
+
+ present_key_value = self_attention_outputs[-1]
+
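+        # The first `query_length` positions are learnable query tokens: only they attend to the
+        # encoder features via cross-attention and they use a dedicated feed-forward branch
+        # (feed_forward_chunk_query); any remaining text positions use the standard branch.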
+ if query_length > 0:
+ query_attention_output = attention_output[:, :query_length, :]
+
+ if self.has_cross_attention:
+ assert (
+ encoder_hidden_states is not None
+ ), "encoder_hidden_states must be given for cross-attention layers"
+ cross_attention_outputs = self.crossattention(
+ query_attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ query_attention_output = cross_attention_outputs[0]
+ outputs = (
+ outputs + cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk_query,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ query_attention_output,
+ )
+ if attention_output.shape[1] > query_length:
+ layer_output_text = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output[:, query_length:, :],
+ )
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+ else:
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+ def feed_forward_chunk_query(self, attention_output):
+ intermediate_output = self.intermediate_query(attention_output)
+ layer_output = self.output_query(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ query_length=0,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = (
+ () if output_attentions and self.config.add_cross_attention else None
+ )
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ if use_cache:
+                    logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(
+ *inputs, past_key_value, output_attentions, query_length
+ )
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ query_length,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    To be used as a decoder, the model needs to be initialized with both the :obj:`is_decoder`
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=False):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: Tensor,
+ input_shape: Tuple[int],
+ device: device,
+ is_decoder: bool,
+ has_query: bool = False,
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = (
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+ <= seq_ids[None, :, None]
+ )
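+                # For example, with seq_length = 3 the causal part (before the prefix/query
+                # handling below) is lower-triangular:
+                #   [[1, 0, 0],
+                #    [1, 1, 0],
+                #    [1, 1, 1]]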
+
+ # add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ if has_query: # UniLM style attention mask
+ causal_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, prefix_seq_len, seq_length),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=1,
+ )
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+ extended_attention_mask = (
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is None:
+ assert (
+ query_embeds is not None
+ ), "You have to specify query_embeds when input_ids is None"
+
+ # past_key_values_length
+ past_key_values_length = (
+ past_key_values[0][0].shape[2] - self.config.query_length
+ if past_key_values is not None
+ else 0
+ )
+
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ query_embeds=query_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ input_shape = embedding_output.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = embedding_output.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length + past_key_values_length)), device=device
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if is_decoder:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask,
+ input_ids.shape,
+ device,
+ is_decoder,
+ has_query=(query_embeds is not None),
+ )
+ else:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
+ 0
+ ].size()
+ else:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
+ ]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ query_length=query_length,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = (
+ self.pooler(sequence_output) if self.pooler is not None else None
+ )
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=True,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction="mean",
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ if labels is not None:
+ use_cache = False
+ if past_key_values is not None:
+ query_embeds = None
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ sequence_output = outputs[0]
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+ if reduction == "none":
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
+ ):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "query_embeds": query_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
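+        # Reorder the cached key/value states so they follow the beam indices selected at
+        # this generation step (required by beam search when past_key_values is used).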
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx) for past_state in layer_past
+ ),
+ )
+ return reordered_past
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+        sequence_output = outputs[0]
+        if query_embeds is not None:
+            sequence_output = sequence_output[:, query_embeds.shape[1]:, :]
+        prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return (
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ )
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/unimernet/models/blip2_models/__init__.py b/unimernet/models/blip2_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/models/blip2_models/__pycache__/Qformer.cpython-310.pyc b/unimernet/models/blip2_models/__pycache__/Qformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a040440a336c86164ff02f23bc90c0af3046ccb
Binary files /dev/null and b/unimernet/models/blip2_models/__pycache__/Qformer.cpython-310.pyc differ
diff --git a/unimernet/models/blip2_models/__pycache__/__init__.cpython-310.pyc b/unimernet/models/blip2_models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66b2031f3f5efe47f280420ea496156ea14f4fe1
Binary files /dev/null and b/unimernet/models/blip2_models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/models/blip2_models/__pycache__/blip2.cpython-310.pyc b/unimernet/models/blip2_models/__pycache__/blip2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b46fcb658d78c631ebcd9a06309f3883c35f6786
Binary files /dev/null and b/unimernet/models/blip2_models/__pycache__/blip2.cpython-310.pyc differ
diff --git a/unimernet/models/blip2_models/blip2.py b/unimernet/models/blip2_models/blip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3829d58c1a97d49893566488aafc95f1e4c8d458
--- /dev/null
+++ b/unimernet/models/blip2_models/blip2.py
@@ -0,0 +1,322 @@
+"""
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import contextlib
+import logging
+import os
+import time
+import datetime
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.nn.functional as F
+
+import unimernet.common.dist_utils as dist_utils
+from unimernet.common.dist_utils import download_cached_file
+from unimernet.common.utils import is_url
+from unimernet.common.logger import MetricLogger
+from unimernet.models.base_model import BaseModel
+from unimernet.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
+from unimernet.models.eva_vit import create_eva_vit_g
+from unimernet.models.clip_vit import create_clip_vit_L
+from transformers import BertTokenizer
+from transformers.utils import logging as tf_logging
+
+tf_logging.set_verbosity_error()
+
+
+class Blip2Base(BaseModel):
+ @classmethod
+ def init_tokenizer(cls, truncation_side="right"):
+ tokenizer = BertTokenizer.from_pretrained("/mnt/lustre/hanxiao/work/bert-base-uncased", truncation_side=truncation_side)
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+ return tokenizer
+
+ def maybe_autocast(self, dtype=torch.float16):
+ # if on cpu, don't use autocast
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+ enable_autocast = self.device != torch.device("cpu")
+
+ if enable_autocast:
+ return torch.cuda.amp.autocast(dtype=dtype)
+ else:
+ return contextlib.nullcontext()
+
+ @classmethod
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
+ encoder_config = BertConfig.from_pretrained("/mnt/lustre/hanxiao/work/bert-base-uncased")
+ encoder_config.encoder_width = vision_width
+        # insert a cross-attention layer every cross_attention_freq blocks (2 by default, i.e. every other block)
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = cross_attention_freq
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel.from_pretrained(
+ "/mnt/lustre/hanxiao/work/bert-base-uncased", config=encoder_config
+ )
+ query_tokens = nn.Parameter(
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
+ )
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+ return Qformer, query_tokens
+
+ def init_vision_encoder(
+ self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+ ):
+ assert model_name in [
+ "eva_clip_g",
+ "eva2_clip_L",
+ "clip_L",
+ ], "vit model must be eva_clip_g, eva2_clip_L or clip_L"
+ if model_name == "eva_clip_g":
+ visual_encoder = create_eva_vit_g(
+ img_size, drop_path_rate, use_grad_checkpoint, precision
+ )
+ # elif model_name == "eva2_clip_L":
+ # visual_encoder = create_eva2_vit_L(
+ # img_size, drop_path_rate, use_grad_checkpoint, precision
+ # )
+ elif model_name == "clip_L":
+ visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
+ ln_vision = LayerNorm(visual_encoder.num_features)
+ self.vit_name = model_name
+ return visual_encoder, ln_vision
+
+ def load_from_pretrained(self, url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+
+ msg = self.load_state_dict(state_dict, strict=False)
+
+ # logging.info("Missing keys {}".format(msg.missing_keys))
+ logging.info("load checkpoint from %s" % url_or_filename)
+
+ return msg
+
+ def get_optimizer_params(self, weight_decay, lr_scale=1):
+ if self.vit_name == "eva_clip_g":
+ vit_num_layers = self.visual_encoder.get_num_layer()
+ lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
+
+ parameter_group_names = {}
+ parameter_group_vars = {}
+
+ for name, param in self.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith(".bias"):
+ group_name = "no_decay"
+ this_weight_decay = 0.
+ else:
+ group_name = "decay"
+ this_weight_decay = weight_decay
+ if 'visual_encoder' in name:
+ layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
+ group_name = "vit_layer_%d_%s" % (layer_id, group_name)
+ else:
+ layer_id = None
+
+ if group_name not in parameter_group_names:
+ if layer_id is not None:
+ scale = lr_scales[layer_id]
+ else:
+ scale = 1
+ parameter_group_names[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale
+ }
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+ # import json
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+ optim_params = list(parameter_group_vars.values())
+ return optim_params
+ else:
+ return super().get_optimizer_params(weight_decay, lr_scale)
+
+ def _lemmatize(self, answers):
+ def apply(answer):
+ doc = self.lemmatizer(answer)
+
+ words = []
+ for token in doc:
+ if token.pos_ in ["NOUN", "VERB"]:
+ words.append(token.lemma_)
+ else:
+ words.append(token.text)
+ answer = " ".join(words)
+
+ return answer
+
+ return [apply(answer) for answer in answers]
+
+ @property
+ def lemmatizer(self):
+ if self._lemmatizer is None:
+ try:
+ import spacy
+
+ self._lemmatizer = spacy.load("en_core_web_sm")
+ except ImportError:
+ logging.error(
+ """
+ Please install spacy and en_core_web_sm model to apply lemmatization.
+ python -m spacy download en_core_web_sm
+ OR
+ import spacy.cli
+ spacy.cli.download("en_core_web_sm")
+ """
+ )
+ exit(1)
+
+ return self._lemmatizer
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+def compute_sim_matrix(model, data_loader, **kwargs):
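+    """
+    Compute image-to-text and text-to-image score matrices for retrieval evaluation.
+
+    All pairs are first ranked by the contrastive (feature dot-product) similarity; the
+    top ``k_test`` candidates per query are then re-scored with the image-text matching
+    head, and that score is added to the contrastive similarity.
+    """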
+ k_test = kwargs.pop("k_test")
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+
+ logging.info("Computing features for evaluation...")
+ start_time = time.time()
+
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_ids = []
+ text_embeds = []
+ text_atts = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i: min(num_text, i + text_bs)]
+ text_input = model.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=35,
+ return_tensors="pt",
+ ).to(model.device)
+ text_feat = model.forward_text(text_input)
+ text_embed = F.normalize(model.text_proj(text_feat))
+ text_embeds.append(text_embed)
+ text_ids.append(text_input.input_ids)
+ text_atts.append(text_input.attention_mask)
+
+ text_embeds = torch.cat(text_embeds, dim=0)
+ text_ids = torch.cat(text_ids, dim=0)
+ text_atts = torch.cat(text_atts, dim=0)
+
+ vit_feats = []
+ image_embeds = []
+ for samples in data_loader:
+ image = samples["image"]
+
+ image = image.to(model.device)
+ image_feat, vit_feat = model.forward_image(image)
+ image_embed = model.vision_proj(image_feat)
+ image_embed = F.normalize(image_embed, dim=-1)
+
+ vit_feats.append(vit_feat.cpu())
+ image_embeds.append(image_embed)
+
+ vit_feats = torch.cat(vit_feats, dim=0)
+ image_embeds = torch.cat(image_embeds, dim=0)
+
+ sims_matrix = []
+ for image_embed in image_embeds:
+ sim_q2t = image_embed @ text_embeds.t()
+ sim_i2t, _ = sim_q2t.max(0)
+ sims_matrix.append(sim_i2t)
+ sims_matrix = torch.stack(sims_matrix, dim=0)
+
+ score_matrix_i2t = torch.full(
+ (len(data_loader.dataset.image), len(texts)), -100.0
+ ).to(model.device)
+
+ num_tasks = dist_utils.get_world_size()
+ rank = dist_utils.get_rank()
+ step = sims_matrix.size(0) // num_tasks + 1
+ start = rank * step
+ end = min(sims_matrix.size(0), start + step)
+
+ for i, sims in enumerate(
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
+ ):
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
+ score = model.compute_itm(
+ image_inputs=image_inputs,
+ text_ids=text_ids[topk_idx],
+ text_atts=text_atts[topk_idx],
+ ).float()
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
+
+ sims_matrix = sims_matrix.t()
+ score_matrix_t2i = torch.full(
+ (len(texts), len(data_loader.dataset.image)), -100.0
+ ).to(model.device)
+
+ step = sims_matrix.size(0) // num_tasks + 1
+ start = rank * step
+ end = min(sims_matrix.size(0), start + step)
+
+ for i, sims in enumerate(
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
+ ):
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
+ score = model.compute_itm(
+ image_inputs=image_inputs,
+ text_ids=text_ids[start + i].repeat(k_test, 1),
+ text_atts=text_atts[start + i].repeat(k_test, 1),
+ ).float()
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
+
+ if dist_utils.is_dist_avail_and_initialized():
+ dist.barrier()
+ torch.distributed.all_reduce(
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
+ )
+ torch.distributed.all_reduce(
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
+ )
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logging.info("Evaluation time {}".format(total_time_str))
+
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
diff --git a/unimernet/models/blip2_models/blip2_vicuna_instruct.py b/unimernet/models/blip2_models/blip2_vicuna_instruct.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5e257b74c8d17db444e4aa78c06956a2f83027
--- /dev/null
+++ b/unimernet/models/blip2_models/blip2_vicuna_instruct.py
@@ -0,0 +1,666 @@
+"""
+Requires transformers 4.28 and above; the implementation may change according to the Llama implementation.
+"""
+import logging
+from packaging import version
+
+import torch
+import torch.nn as nn
+
+import transformers
+
+from unimernet.common.registry import registry
+from unimernet.models.blip2_models.blip2 import Blip2Base, disabled_train
+
+
+@registry.register_model("blip2_vicuna_instruct")
+class Blip2VicunaInstruct(Blip2Base):
+ """
+ BLIP2 Vicuna model.
+ Supported model types:
+ - vicuna7b
+ - vicuna13b
+ Usage:
+ >>> from unimernet.models import load_model
+ >>> model = load_model("blip2_vicuna_instruct", "vicuna7b")
+ """
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "vicuna7b": "configs/models/blip2_instruct_vicuna7b.yaml",
+ "vicuna13b": "configs/models/blip2_instruct_vicuna13b.yaml",
+ "minigpt4_vicuna7b": "configs/models/mini_gpt4_vicuna7b.yaml",
+ "minigpt4_vicuna13b": "configs/models/mini_gpt4_vicuna13b.yaml",
+ }
+
+ def __init__(
+ self,
+ vit_model="eva_clip_g",
+ img_size=224,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision="fp16",
+ freeze_vit=True,
+ freeze_vit_ln=False,
+ num_query_token=32,
+ llm_model="",
+ prompt="",
+ max_txt_len=128,
+ max_output_txt_len=256,
+ apply_lemmatizer=False,
+ qformer_text_input=True,
+ truncate_q_former_output=True
+ ):
+ super().__init__()
+ transformers_version = version.parse(transformers.__version__)
+ assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
+ from transformers import LlamaTokenizer
+ from unimernet.models.blip2_models.modeling_llama import LlamaForCausalLM
+
+ self.tokenizer = self.init_tokenizer(truncation_side="left")
+
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+ )
+ if freeze_vit:
+ for name, param in self.visual_encoder.named_parameters():
+ param.requires_grad = False
+ self.visual_encoder = self.visual_encoder.eval()
+ self.visual_encoder.train = disabled_train
+ logging.info("freeze vision encoder")
+
+ if freeze_vit_ln:
+ for name, param in self.ln_vision.named_parameters():
+ param.requires_grad = False
+ self.ln_vision = self.ln_vision.eval()
+ self.ln_vision.train = disabled_train
+ logging.info("freeze vit layner norm")
+
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features
+ )
+
+ if not qformer_text_input:
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ else:
+ self.Qformer.resize_token_embeddings(len(self.tokenizer))
+ self.Qformer.cls = None
+
+ self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
+ self.llm_tokenizer_for_generate = LlamaTokenizer.from_pretrained(llm_model, use_fast=False,
+ truncation_side="left")
+ self.llm_model = LlamaForCausalLM.from_pretrained(
+ llm_model, torch_dtype=torch.float16
+ )
+ self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+ # self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
+
+ self.llm_tokenizer_for_generate.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer_for_generate.add_special_tokens({'bos_token': '</s>'})
+        self.llm_tokenizer_for_generate.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer_for_generate.add_special_tokens({'unk_token': '</s>'})
+ self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
+
+ # self.eos_token_id = self.llm_tokenizer(
+ # self.llm_tokenizer.eos_token, add_special_tokens=False
+ # ).input_ids[0]
+
+ for name, param in self.llm_model.named_parameters():
+ param.requires_grad = False
+
+ self.llm_proj = nn.Linear(
+ self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
+ )
+
+ self.max_txt_len = max_txt_len
+ self.max_output_txt_len = max_output_txt_len
+ self.prompt = prompt
+ prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
+ self.prompt_length = prompt_tokens.attention_mask.sum(1)
+
+ self._lemmatizer = None
+
+ self.qformer_text_input = qformer_text_input
+ self.truncate_q_former_output = truncate_q_former_output
+
+ def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
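+        """
+        Splice each answer right after the non-padded portion of its instruction and move the
+        instruction padding to the end. Returns the merged tokens and, per sample, the length of
+        the instruction part (used later to mask the instruction out of the loss).
+        """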
+ input_part_targets_len = []
+ llm_tokens = {"input_ids": [], "attention_mask": []}
+ for i in range(input_ids.size(0)):
+ this_input_ones = input_atts[i].sum()
+ input_part_targets_len.append(this_input_ones)
+ llm_tokens['input_ids'].append(
+ torch.cat([
+ input_ids[i][:this_input_ones],
+ output_ids[i][1:],
+ input_ids[i][this_input_ones:]
+ ])
+ )
+ llm_tokens['attention_mask'].append(
+ torch.cat([
+ input_atts[i][:this_input_ones],
+ output_atts[i][1:],
+ input_atts[i][this_input_ones:]
+ ])
+ )
+ llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+ llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
+ return llm_tokens, input_part_targets_len
+
+ def forward(self, samples):
+ # print('-----------------')
+ # print(samples["text_input"])
+ # print(samples["text_output"])
+ # print('-----------------')
+
+ image = samples["image"]
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(self.visual_encoder(image))
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+ bs = image.size(0)
+
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ if self.qformer_text_input:
+ text_Qformer = self.tokenizer(
+ samples["text_input"],
+ padding='longest',
+ truncation=True,
+ max_length=self.max_txt_len,
+ return_tensors="pt",
+ ).to(image.device)
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+ query_output = self.Qformer.bert(
+ text_Qformer.input_ids,
+ attention_mask=Qformer_atts,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ else:
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ if self.truncate_q_former_output:
+ inputs_llm = self.llm_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+ else:
+ inputs_llm = self.llm_proj(query_output.last_hidden_state)
+ atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+ self.llm_tokenizer.padding_side = "right"
+ self.llm_tokenizer.truncation_side = 'left'
+ text_input_tokens = self.llm_tokenizer(
+ samples['text_input'],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=self.max_txt_len,
+ ).to(image.device)
+
+ self.llm_tokenizer.truncation_side = 'right'
+ text_output_tokens = self.llm_tokenizer(
+ [t + self.llm_tokenizer.eos_token for t in samples['text_output']],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=self.max_output_txt_len,
+ ).to(image.device)
+
+ llm_tokens, input_part_targets_len = self.concat_text_input_output(
+ text_input_tokens.input_ids,
+ text_input_tokens.attention_mask,
+ text_output_tokens.input_ids,
+ text_output_tokens.attention_mask,
+ )
+
+ # do not apply loss to the padding
+ targets = llm_tokens['input_ids'].masked_fill(
+ llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
+ )
+
+ # do not apply loss to the text input (i.e., instruction)
+ for i, l in enumerate(input_part_targets_len):
+ targets[i][:l] = -100
+
+ # do not apply loss to the query tokens
+ empty_targets = (
+ torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
+ )
+ targets = torch.cat([empty_targets, targets], dim=1)
+
+ inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
+ inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+ attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
+
+ with self.maybe_autocast():
+ outputs = self.llm_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ use_cache=False,
+ )
+
+ loss = outputs.loss
+
+ return {"loss": loss}
+
+ def get_vision_feats(self, image, prompt):
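+        """
+        Encode the image with the ViT and the Q-Former (conditioned on the prompt) and project the
+        Q-Former outputs into the LLM embedding space. Returns the projected visual embeddings and
+        their attention mask.
+        """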
+ bs = image.size(0)
+
+ if isinstance(prompt, str):
+ prompt = [prompt] * bs
+ else:
+ assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
+
+ query_tokens = self.query_tokens.expand(bs, -1, -1)
+
+ text_Qformer = self.tokenizer(
+ prompt,
+ padding='longest',
+ truncation=True,
+ max_length=self.max_txt_len,
+ return_tensors="pt",
+ ).to(image.device)
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(self.visual_encoder(image))
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+ query_output = self.Qformer.bert(
+ text_Qformer.input_ids,
+ attention_mask=Qformer_atts,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ if self.truncate_q_former_output:
+ inputs_llm = self.llm_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+ else:
+ inputs_llm = self.llm_proj(query_output.last_hidden_state)
+ atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+ return inputs_llm, atts_llm
+
+ def shift_padding_to_left(self, inputs_embeds, attention_mask):
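+        """
+        Move right padding to the left of each sequence so generation starts right after the last
+        real token. Returns the re-ordered input embeddings and attention mask.
+        """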
+ llm_tokens = {"input_embeds": [], "attention_mask": []}
+ for i in range(inputs_embeds.size(0)):
+ this_input_ones = attention_mask[i].sum()
+ llm_tokens['input_embeds'].append(
+ torch.cat([
+ inputs_embeds[i][this_input_ones:],
+ inputs_embeds[i][:this_input_ones],
+ ])
+ )
+ llm_tokens['attention_mask'].append(
+ torch.cat([
+ attention_mask[i][this_input_ones:],
+ attention_mask[i][:this_input_ones],
+ ])
+ )
+ llm_tokens['input_embeds'] = torch.stack(llm_tokens['input_embeds'])
+ llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
+ return llm_tokens['input_embeds'], llm_tokens['attention_mask']
+
+ @torch.no_grad()
+ def generate(
+ self,
+ samples,
+ use_nucleus_sampling=False,
+ num_beams=5,
+ max_length=256,
+ min_length=1,
+ top_p=0.9,
+ repetition_penalty=1.5,
+ length_penalty=1,
+ num_captions=1,
+ temperature=1,
+ ):
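+        """
+        Autoregressive generation: the projected Q-Former visual features are prepended to the
+        embedded prompt, padding is shifted to the left, and the LLM decodes with beam search or
+        nucleus sampling. Returns the decoded strings.
+        """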
+
+ if "prompt" in samples.keys():
+ prompt = samples["prompt"]
+ else:
+ prompt = self.prompt
+
+ image = samples["image"]
+
+ inputs_llm, atts_llm = self.get_vision_feats(image, prompt)
+
+ self.llm_tokenizer_for_generate.padding_side = "right"
+
+        self.llm_tokenizer_for_generate.pad_token = self.llm_tokenizer_for_generate.eos_token  # pad with eos during generation
+        ori_pad_token_id = self.llm_model.config.pad_token_id  # saved so it can be restored after generation
+        self.llm_model.config.pad_token_id = self.llm_model.config.eos_token_id
+
+ if "prefix" in samples:
+ prompt = [f"{prompt_} {prefix_}".strip() for prompt_, prefix_ in zip(prompt, samples["prefix"])]
+
+ llm_tokens = self.llm_tokenizer_for_generate(
+ prompt,
+ padding="longest",
+ return_tensors="pt",
+ ).to(image.device)
+
+ inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
+ inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+ inputs_embeds = inputs_embeds.to(next(self.llm_model.parameters()).dtype)
+ attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
+ inputs_embeds, attention_mask = self.shift_padding_to_left(inputs_embeds, attention_mask)
+
+ with self.maybe_autocast():
+ outputs = self.llm_model.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ do_sample=use_nucleus_sampling,
+ top_p=top_p,
+ temperature=temperature,
+ num_beams=num_beams,
+ max_length=max_length,
+ min_length=min_length,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty,
+ num_return_sequences=num_captions,
+ use_cache=True
+ )
+
+ outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
+        outputs[outputs == -1] = 1  # replace invalid (-1) ids so decoding does not fail
+ output_text = self.llm_tokenizer_for_generate.batch_decode(outputs, skip_special_tokens=True)
+ output_text = [text.strip() for text in output_text]
+
+ self.llm_model.config.pad_token_id = ori_pad_token_id
+
+ return output_text
+
+ @torch.no_grad()
+ def generate_multi(
+ self,
+ samples,
+ use_nucleus_sampling=False,
+ num_beams=5,
+ max_length=256,
+ min_length=1,
+ top_p=0.9,
+ repetition_penalty=1.5,
+ length_penalty=1,
+ temperature=1,
+ ):
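+        """
+        Beam-search generation that keeps all hypotheses: for each input, returns the list of
+        `num_beams` decoded strings together with heuristic confidence scores derived from the
+        beam sequence scores.
+        """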
+
+ if "prompt" in samples.keys():
+ prompt = samples["prompt"]
+ else:
+ prompt = self.prompt
+
+ image = samples["image"]
+
+ inputs_llm, atts_llm = self.get_vision_feats(image, prompt)
+
+ self.llm_tokenizer_for_generate.padding_side = "right"
+
+        self.llm_tokenizer_for_generate.pad_token = self.llm_tokenizer_for_generate.eos_token  # pad with eos during generation
+        ori_pad_token_id = self.llm_model.config.pad_token_id  # saved so it can be restored after generation
+        self.llm_model.config.pad_token_id = self.llm_model.config.eos_token_id
+
+ if "prefix" in samples:
+ prompt = [f"{prompt_} {prefix_}".strip() for prompt_, prefix_ in zip(prompt, samples["prefix"])]
+
+ llm_tokens = self.llm_tokenizer_for_generate(
+ prompt,
+ padding="longest",
+ return_tensors="pt",
+ ).to(image.device)
+
+ inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
+ inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+ inputs_embeds = inputs_embeds.to(next(self.llm_model.parameters()).dtype)
+ attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
+ inputs_embeds, attention_mask = self.shift_padding_to_left(inputs_embeds, attention_mask)
+
+ with self.maybe_autocast():
+ raw_output = self.llm_model.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ do_sample=use_nucleus_sampling,
+ top_p=top_p,
+ temperature=temperature,
+ num_beams=num_beams,
+ max_length=max_length,
+ min_length=min_length,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty,
+ num_return_sequences=num_beams,
+ output_scores=True,
+ return_dict_in_generate=True,
+ use_cache=True
+ )
+ outputs = raw_output.sequences
+ outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
+        outputs[outputs == -1] = 1  # replace invalid (-1) ids so decoding does not fail
+ output_text = self.llm_tokenizer_for_generate.batch_decode(outputs, skip_special_tokens=True)
+
+ output_text = [text.strip() for text in output_text]
+        scores = torch.exp(raw_output.sequences_scores).cpu().numpy() ** 3 * 100  # TODO: heuristic confidence (cubed sequence probability, scaled to 0-100)
+
+ all_texts = []
+ all_scores = []
+ for i in range(0, len(output_text), num_beams):
+ this_text = output_text[i:i + num_beams]
+ all_texts.append(this_text)
+ this_score = scores[i: i + num_beams]
+ all_scores.append(this_score)
+
+ self.llm_model.config.pad_token_id = ori_pad_token_id
+
+ return all_texts, all_scores
+
+ def predict_by_rank(
+ self,
+ samples,
+ **kwargs
+ ):
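+        """
+        Rank the candidate answers by their language-modeling loss conditioned on the image and
+        prompt, and return the candidate with the lowest loss. Only batch size 1 is supported.
+        """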
+ image = samples["image"]
+ prompt = samples["prompt"]
+ candidates = samples["candidates"][0]
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ assert image.size(0) == len(prompt) == 1, "When doing predict by rank, the batch size must be 1."
+
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(self.visual_encoder(image))
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+ batch_size = len(candidates)
+
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ if self.qformer_text_input:
+ text_Qformer = self.tokenizer(
+ prompt,
+ padding='longest',
+ truncation=True,
+ max_length=self.max_txt_len,
+ return_tensors="pt",
+ ).to(image.device)
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+ query_output = self.Qformer.bert(
+ text_Qformer.input_ids,
+ attention_mask=Qformer_atts,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ else:
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ if self.truncate_q_former_output:
+ inputs_llm = self.llm_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+ else:
+ inputs_llm = self.llm_proj(query_output.last_hidden_state)
+ atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+ self.llm_tokenizer.padding_side = "right"
+ self.llm_tokenizer.truncation_side = 'left'
+ text_input_tokens = self.llm_tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=self.max_txt_len,
+ ).to(image.device)
+
+ inputs_llm = inputs_llm.repeat(batch_size, 1, 1)
+ atts_llm = atts_llm.repeat(batch_size, 1)
+ text_input_ids = text_input_tokens.input_ids.repeat(batch_size, 1)
+ text_input_mask = text_input_tokens.attention_mask.repeat(batch_size, 1)
+
+ self.llm_tokenizer.truncation_side = 'right'
+ text_output_tokens = self.llm_tokenizer(
+ [t + self.llm_tokenizer.eos_token for t in candidates],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=self.max_output_txt_len,
+ ).to(image.device)
+
+ llm_tokens, input_part_targets_len = self.concat_text_input_output(
+ text_input_ids,
+ text_input_mask,
+ text_output_tokens.input_ids,
+ text_output_tokens.attention_mask,
+ )
+
+ # do not apply loss to the padding
+ targets = llm_tokens['input_ids'].masked_fill(
+ llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
+ )
+
+ # do not apply loss to the text input (i.e., instruction)
+ for i, l in enumerate(input_part_targets_len):
+ targets[i][:l] = -100
+
+ # do not apply loss to the query tokens
+ empty_targets = (
+ torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
+ )
+ targets = torch.cat([empty_targets, targets], dim=1)
+
+ inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
+ inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+ attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
+
+ with self.maybe_autocast():
+ outputs = self.llm_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ reduction="none",
+ use_cache=False
+ )
+
+ loss = outputs.loss.view(batch_size)
+ top1 = int(torch.argsort(loss, dim=-1)[0])
+
+ return [candidates[top1]]
+
+ def _lemmatize(self, answers):
+ def apply(answer):
+ doc = self.lemmatizer(answer)
+
+ words = []
+ for token in doc:
+ if token.pos_ in ["NOUN", "VERB"]:
+ words.append(token.lemma_)
+ else:
+ words.append(token.text)
+ answer = " ".join(words)
+
+ return answer
+
+ return [apply(answer) for answer in answers]
+
+ @property
+ def lemmatizer(self):
+ if self._lemmatizer is None:
+ try:
+ import spacy
+
+ self._lemmatizer = spacy.load("en_core_web_sm")
+ except ImportError:
+ logging.error(
+ """
+ Please install spacy and en_core_web_sm model to apply lemmatization.
+ python -m spacy download en_core_web_sm
+ OR
+ import spacy.cli
+ spacy.cli.download("en_core_web_sm")
+ """
+ )
+ exit(1)
+
+ return self._lemmatizer
+
+ @classmethod
+ def from_config(cls, cfg):
+ vit_model = cfg.get("vit_model", "eva_clip_g")
+ img_size = cfg.get("image_size")
+ num_query_token = cfg.get("num_query_token")
+ llm_model = cfg.get("llm_model")
+
+ drop_path_rate = cfg.get("drop_path_rate", 0)
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+ vit_precision = cfg.get("vit_precision", "fp16")
+ freeze_vit = cfg.get("freeze_vit", True)
+ freeze_vit_ln = cfg.get("freeze_vit_ln", False)
+ prompt = cfg.get("prompt", "")
+ max_txt_len = cfg.get("max_txt_len", 128)
+ max_output_txt_len = cfg.get("max_output_txt_len", 256)
+
+ apply_lemmatizer = cfg.get("apply_lemmatizer", False)
+
+ qformer_text_input = cfg.get("qformer_text_input", True)
+ truncate_q_former_output = cfg.get("truncate_q_former_output", True)
+
+ model = cls(
+ vit_model=vit_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ freeze_vit_ln=freeze_vit_ln,
+ num_query_token=num_query_token,
+ llm_model=llm_model,
+ prompt=prompt,
+ max_txt_len=max_txt_len,
+ max_output_txt_len=max_output_txt_len,
+ apply_lemmatizer=apply_lemmatizer,
+ qformer_text_input=qformer_text_input,
+ truncate_q_former_output=truncate_q_former_output
+ )
+
+ model.load_checkpoint_from_config(cfg)
+
+ return model
diff --git a/unimernet/models/blip2_models/modeling_llama.py b/unimernet/models/blip2_models/modeling_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..08d16a4abfb0a83dc416888755e31ea55c5be02b
--- /dev/null
+++ b/unimernet/models/blip2_models/modeling_llama.py
@@ -0,0 +1,994 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
+ SequenceClassifierOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, \
+ replace_return_docstrings
+from transformers.models.llama.configuration_llama import LlamaConfig
+from einops import rearrange
+
+FLASH_ATTN_FLAG = True
+try:
+ from flash_attn.flash_attn_interface import ( # pip3 install "flash-attn>=2.0"
+ flash_attn_varlen_qkvpacked_func,
+ )
+ from flash_attn.bert_padding import unpad_input, pad_input
+
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ logging.warning(
+            "Flash attention is only supported on A100 or H100 GPUs during training due to head dim > 64 backward. "
+            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+ FLASH_ATTN_FLAG = False
+except ImportError:
+ FLASH_ATTN_FLAG = False
+    logging.warning("flash-attn is not installed; falling back to the standard attention implementation")
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ return (self.weight * hidden_states).to(input_dtype)
+
+
+class LlamaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # The permutation differs from the paper, but it yields the same result.
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # The permutation differs from the paper, but it yields the same result.
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
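+    """Gather the cached cos/sin values at `position_ids` and apply the rotary embedding to q and k."""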
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def flash_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel
+
+ attention_mask: [bsz, q_len]
+ """
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ # [bsz, q_len, nh, hd]
+ # [bsz, nh, q_len, hd]
+
+ kv_seq_len = key_states.shape[-2]
+ assert past_key_value is None, "past_key_value is not supported"
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+ assert not output_attentions, "output_attentions is not supported"
+ assert not use_cache, "use_cache is not supported"
+
+ # Flash attention codes from
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+ # transform the data into the format required by flash attention
+ qkv = torch.stack(
+ [query_states, key_states, value_states], dim=2
+ ) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
+ # the attention_mask should be the same as the key_padding_mask
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
+ max_s = q_len
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+ )
+ output_unpad = flash_attn_varlen_qkvpacked_func(
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = rearrange(
+ pad_input(
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+ ),
+ "b s (h d) -> b s h d",
+ h=nheads,
+ )
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
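+        # Dispatch to the packed flash-attention path when flash-attn is available and no KV cache
+        # is requested; otherwise fall back to the standard attention implementation below.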
+ if FLASH_ATTN_FLAG and not use_cache:
+ return self.flash_attn_forward(hidden_states, attention_mask, position_ids, past_key_value,
+ output_attentions, use_cache)
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LlamaAttention(config=config)
+ self.mlp = LlamaMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlamaModel):
+ module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
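+        # Keep the 2-D padding mask unchanged when the flash-attention path will be used;
+        # otherwise expand it into the usual 4-D causal mask expected by standard attention.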
+ if not (FLASH_ATTN_FLAG and (use_cache is False)):
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ reduction: Optional[str] = "mean",
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(reduction=reduction)
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+ if reduction == "none":
+ # loss = loss.view(logits.size(0), -1).sum(1)
+ loss = loss.view(logits.size(0), -1).mean(1)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
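+        """
+        When a KV cache is present, only the last token is fed to the model and position ids are
+        sliced accordingly; `inputs_embeds` are only used on the first generation step.
+        """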
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/unimernet/models/blip2_models/modeling_llama_.py b/unimernet/models/blip2_models/modeling_llama_.py
new file mode 100644
index 0000000000000000000000000000000000000000..372889e0f7495ae7db0cdcd1bebd748833f66e93
--- /dev/null
+++ b/unimernet/models/blip2_models/modeling_llama_.py
@@ -0,0 +1,885 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+    Make causal mask used for auto-regressive (decoder) self-attention.
+ """
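+    # Disallowed (future) positions get the dtype's minimum value, which acts as
+    # -inf before the softmax; allowed positions, including every cached
+    # past-key-value column, get 0.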
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
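+    # A mask value of 1 (attend) maps to 0 and a value of 0 (padding) maps to the
+    # dtype's minimum, so the result can be added directly to the attention scores.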
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
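+        # y = weight * x / sqrt(mean(x^2) + eps); the mean of squares is computed
+        # in float32 for stability, then the result is cast back to the input dtype.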
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ return (self.weight * hidden_states).to(input_dtype)
+
+
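+# Rotary position embedding (RoPE): cos/sin tables are precomputed up to
+# max_position_embeddings from per-pair inverse frequencies 1 / base**(2i/dim)
+# and later applied to the query/key vectors inside the attention layers.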
+class LlamaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
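+    # Gather the cos/sin rows matching each token's position id, then rotate q and
+    # k pairwise: x -> x * cos + rotate_half(x) * sin.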
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
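+        # Gated feed-forward: down_proj(act(gate_proj(x)) * up_proj(x)); with the
+        # usual SiLU activation this is the SwiGLU block used by LLaMA.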
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+            )
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LlamaAttention(config=config)
+ self.mlp = LlamaMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
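+        # Pre-norm block: RMSNorm -> self-attention -> residual add, then
+        # RMSNorm -> MLP -> residual add.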
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlamaModel):
+ module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
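+        # When a KV cache is supplied, position ids and the attention mask are
+        # offset by the cached length, so only the new tokens need to be passed in.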
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ reduction: Optional[str] = "mean",
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(reduction=reduction)
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
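+            # With reduction="none", average the per-token losses within each
+            # sequence so that one loss value is returned per sample.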
+ if reduction == "none":
+ # loss = loss.view(logits.size(0), -1).sum(1)
+ loss = loss.view(logits.size(0), -1).mean(1)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
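+        # During incremental decoding only the newest token is fed to the model;
+        # position ids are rebuilt from the attention mask so that left-padded
+        # batches still get correct positions.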
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
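+        # Reorder the cached key/value states to follow the beams selected at each
+        # step of beam search.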
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+    each row of the batch).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+
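+        # Pool the classification logits at the last non-padding token of each
+        # sequence (or simply at the final position when no pad token is defined).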
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
\ No newline at end of file
diff --git a/unimernet/models/clip_vit.py b/unimernet/models/clip_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5547f2756531e8db014a99ef7c70ee3a4ce1533
--- /dev/null
+++ b/unimernet/models/clip_vit.py
@@ -0,0 +1,254 @@
+from collections import OrderedDict
+from itertools import repeat
+import collections.abc
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+from unimernet.models.eva_vit import convert_weights_to_fp16
+from unimernet.common.dist_utils import download_cached_file
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu1(self.bn1(self.conv1(x)))
+ out = self.relu2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu3(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
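+        # The first token is the prepended mean-pooled query; its attended output
+        # serves as the pooled image feature.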
+ return x[0]
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
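+    # Sigmoid-based GELU approximation used by CLIP: x * sigmoid(1.702 * x).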
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ if use_grad_checkpointing:
+ self.attn = checkpoint_wrapper(self.attn)
+ self.mlp = checkpoint_wrapper(self.mlp)
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)])
+
+ def forward(self, x: torch.Tensor):
+ return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.num_features = width
+ self.num_heads = heads
+ self.num_patches = (input_resolution // patch_size) ** 2
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
+
+        # self.ln_final = LayerNorm(width)
+
+ def forward(self, x: torch.Tensor):
+
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+        # x = self.ln_final(x)
+ return x
+
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+to_2tuple = _ntuple(2)
+
+
+def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
+ # Rescale the grid of position embeddings when loading from state_dict
+ old_pos_embed = state_dict.get('positional_embedding', None)
+
+ grid_size = round((model.positional_embedding.shape[0] - 1) ** 0.5)
+ if old_pos_embed is None:
+ return
+ grid_size = to_2tuple(grid_size)
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+ if new_seq_len == old_pos_embed.shape[0]:
+ return
+
+ if extra_tokens:
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+ else:
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
+
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+    print('Resizing position embedding grid-size from %s to %s' % (old_grid_size, grid_size))
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+ pos_emb_img = F.interpolate(
+ pos_emb_img,
+ size=grid_size,
+ mode=interpolation,
+ align_corners=True,
+ )
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+ if pos_emb_tok is not None:
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+ else:
+ new_pos_embed = pos_emb_img
+ state_dict['positional_embedding'] = new_pos_embed
+
+
+def create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp16"):
+ model = VisionTransformer(
+ input_resolution=img_size,
+ patch_size=14,
+ width=1024,
+ layers=23,
+ heads=16,
+ use_grad_checkpointing=use_checkpoint,
+ )
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
+ cached_file = download_cached_file(
+ url, check_hash=False, progress=True
+ )
+ state_dict = torch.load(cached_file, map_location="cpu")
+    interpolate_pos_embed(model, state_dict)
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
+ # print(incompatible_keys)
+
+ if precision == "fp16":
+ convert_weights_to_fp16(model)
+ return model
diff --git a/unimernet/models/eva_vit.py b/unimernet/models/eva_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c495afeb15a6f8be0159cbc4da95e43a3513a33d
--- /dev/null
+++ b/unimernet/models/eva_vit.py
@@ -0,0 +1,448 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+
+from unimernet.common.dist_utils import download_cached_file
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ **kwargs
+ }
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+        # commented out, following the original BERT implementation
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token to cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None):
+ B, N, C = x.shape
+ qkv_bias = None
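+        # The fused qkv projection uses learnable biases only for q and v; the key
+        # bias is fixed at zero, following the BEiT/EVA attention design.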
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token to cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ def forward(self):
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
+ super().__init__()
+ self.image_size = img_size
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+ self.use_checkpoint = use_checkpoint
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+ for i in range(depth)])
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ # trunc_normal_(self.mask_token, std=.02)
+ # if isinstance(self.head, nn.Linear):
+ # trunc_normal_(self.head.weight, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ # if isinstance(self.head, nn.Linear):
+ # self.head.weight.data.mul_(init_scale)
+ # self.head.bias.data.mul_(init_scale)
+
+ def fix_init_weight(self):
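+        # Rescale the output projections of each residual branch by
+        # 1 / sqrt(2 * layer_id), the depth-scaled initialization used by BEiT/EVA.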
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+ else:
+ x = blk(x, rel_pos_bias)
+ return x
+
+ # x = self.norm(x)
+
+ # if self.fc_norm is not None:
+ # t = x[:, 1:, :]
+ # return self.fc_norm(t.mean(1))
+ # else:
+ # return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ # x = self.head(x)
+ return x
+
+ def get_intermediate_layers(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ features = []
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ x = blk(x, rel_pos_bias)
+ features.append(x)
+
+ return features
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
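A quick sketch of what `interpolate_pos_embed` does (not part of the diff): it splits off the extra CLS token, reshapes the remaining position tokens back into their 2-D grid, bicubically resizes that grid to the new patch grid, and flattens again. The example below only checks shapes, resizing a 224px/patch-14 grid (16x16) for a hypothetical 336px model (24x24), and assumes the `VisionTransformer` defaults defined in this file.

```python
import torch

# pretend checkpoint: 1 CLS token + 16*16 patch tokens, embed_dim 1408
ckpt = {"pos_embed": torch.randn(1, 1 + 16 * 16, 1408)}

# tiny model at a larger input resolution (depth=1 keeps it cheap)
model_336 = VisionTransformer(img_size=336, patch_size=14, embed_dim=1408,
                              depth=1, num_heads=16, use_abs_pos_emb=True)

interpolate_pos_embed(model_336, ckpt)
print(ckpt["pos_embed"].shape)  # torch.Size([1, 577, 1408]) -> 1 CLS + 24*24 patches
```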
+
+def convert_weights_to_fp16(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ # tensor = getattr(l, attr)
+ # if tensor is not None:
+ # tensor.data = tensor.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
+ model = VisionTransformer(
+ img_size=img_size,
+ patch_size=14,
+ use_mean_pooling=False,
+ embed_dim=1408,
+ depth=39,
+ num_heads=1408 // 88,
+ mlp_ratio=4.3637,
+ qkv_bias=True,
+ drop_path_rate=drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ use_checkpoint=use_checkpoint,
+ )
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+ cached_file = download_cached_file(
+ url, check_hash=False, progress=True
+ )
+ state_dict = torch.load(cached_file, map_location="cpu")
+ interpolate_pos_embed(model, state_dict)
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
+ # print(incompatible_keys)
+
+ if precision == "fp16":
+ # model.to("cuda")
+ convert_weights_to_fp16(model)
+ return model
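For context, a minimal usage sketch of `create_eva_vit_g` (the import path is assumed from the repository layout): it builds the ViT-g/14 encoder, downloads the EVA checkpoint from the LAVIS URL on first use, and returns patch features rather than logits, since `forward` stops at `forward_features`.

```python
import torch
from unimernet.models.eva_vit import create_eva_vit_g  # assumed module path

# precision="fp32" skips the fp16 weight conversion
model = create_eva_vit_g(img_size=224, drop_path_rate=0.0, precision="fp32").eval()

images = torch.randn(2, 3, 224, 224)        # two RGB images
with torch.no_grad():
    tokens = model(images)

# 224 / 14 = 16 -> 16*16 = 256 patch tokens, plus one CLS token
print(tokens.shape)                          # torch.Size([2, 257, 1408])
```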
diff --git a/unimernet/models/unimernet/__init__.py b/unimernet/models/unimernet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/models/unimernet/__pycache__/__init__.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..474a8a8878e82ee7ab895af876e07d469038356d
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/configuration_unimernet_decoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/configuration_unimernet_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aef80e0ababefc0707c88033d942ed541903fddd
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/configuration_unimernet_decoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/configuration_unimernet_encoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/configuration_unimernet_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c69d47a15353280939c43d456713882fa257d69e
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/configuration_unimernet_encoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/encoder_decoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/encoder_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..386cda1a50aff3a9cc6dc7a0007f44f06148a219
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/encoder_decoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/modeling_unimernet_decoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/modeling_unimernet_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a342676a6daacc8fa1142a0cebce81ebffc567a2
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/modeling_unimernet_decoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/modeling_unimernet_encoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/modeling_unimernet_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1326b01ebe093a348d72b6c6bd79b9b9915f3f18
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/modeling_unimernet_encoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/processor.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5451cf820f71b4ada76773203ba3db52c27e6800
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/processor.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/unimernet.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/unimernet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e4fdac683ba4855eaab5174ed1e3798f8f56a59
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/unimernet.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/configuration_unimernet_decoder.py b/unimernet/models/unimernet/configuration_unimernet_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa794fd714f1cb018ec5b5e5d2173c9cd7855c3
--- /dev/null
+++ b/unimernet/models/unimernet/configuration_unimernet_decoder.py
@@ -0,0 +1,387 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MBART model configuration"""
+
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from transformers import PreTrainedTokenizer
+from transformers.configuration_utils import PretrainedConfig
+from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from transformers.onnx.utils import compute_effective_axis_dimension
+from transformers.utils import TensorType, is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MBartConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the MBART
+ [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for classifier.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by dividing by sqrt(d_model).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models)
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+
+ Example:
+
+ ```python
+ >>> from transformers import MBartConfig, MBartModel
+
+ >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
+ >>> configuration = MBartConfig()
+
+ >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
+ >>> model = MBartModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mbart"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ max_position_embeddings=1024,
+ encoder_layers=12,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=12,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ use_cache=True,
+ is_encoder_decoder=True,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ forced_eos_token_id=2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+
+# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->MBart
+class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task in ["default", "seq2seq-lm"]:
+ common_inputs = OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+ ]
+ )
+
+ if self.use_past:
+ common_inputs["decoder_input_ids"] = {0: "batch"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+ else:
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+ elif self.task == "causal-lm":
+ # TODO: figure this case out.
+ common_inputs = OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+ ]
+ )
+ if self.use_past:
+ num_encoder_layers, _ = self.num_layers
+ for i in range(num_encoder_layers):
+ common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+ common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+ else:
+ common_inputs = OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+ ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
+ ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
+ ]
+ )
+
+ return common_inputs
+
+ @property
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task in ["default", "seq2seq-lm"]:
+ common_outputs = super().outputs
+ else:
+ common_outputs = super(OnnxConfigWithPast, self).outputs
+ if self.use_past:
+ num_encoder_layers, _ = self.num_layers
+ for i in range(num_encoder_layers):
+ common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+ common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+ return common_outputs
+
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+ tokenizer, batch_size, seq_length, is_pair, framework
+ )
+
+ # Generate decoder inputs
+ decoder_seq_length = seq_length if not self.use_past else 1
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
+ )
+ decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
+
+ if self.use_past:
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+ else:
+ import torch
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
+ num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
+ encoder_shape = (
+ batch,
+ num_encoder_attention_heads,
+ encoder_seq_length,
+ self._config.hidden_size // num_encoder_attention_heads,
+ )
+ decoder_past_length = decoder_seq_length + 3
+ decoder_shape = (
+ batch,
+ num_decoder_attention_heads,
+ decoder_past_length,
+ self._config.hidden_size // num_decoder_attention_heads,
+ )
+
+ common_inputs["decoder_attention_mask"] = torch.cat(
+ [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
+ )
+
+ common_inputs["past_key_values"] = []
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
+ num_encoder_layers, num_decoder_layers = self.num_layers
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
+ max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
+ remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
+
+ for _ in range(min_num_layers):
+ common_inputs["past_key_values"].append(
+ (
+ torch.zeros(decoder_shape),
+ torch.zeros(decoder_shape),
+ torch.zeros(encoder_shape),
+ torch.zeros(encoder_shape),
+ )
+ )
+ # TODO: test this.
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
+ for _ in range(min_num_layers, max_num_layers):
+ common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
+ return common_inputs
+
+ def _generate_dummy_inputs_for_causal_lm(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+ tokenizer, batch_size, seq_length, is_pair, framework
+ )
+
+ if self.use_past:
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+ else:
+ import torch
+ batch, seqlen = common_inputs["input_ids"].shape
+ # Not using the same length for past_key_values
+ past_key_values_length = seqlen + 2
+ num_encoder_layers, _ = self.num_layers
+ num_encoder_attention_heads, _ = self.num_attention_heads
+ past_shape = (
+ batch,
+ num_encoder_attention_heads,
+ past_key_values_length,
+ self._config.hidden_size // num_encoder_attention_heads,
+ )
+
+ mask_dtype = common_inputs["attention_mask"].dtype
+ common_inputs["attention_mask"] = torch.cat(
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+ )
+ common_inputs["past_key_values"] = [
+ (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
+ ]
+ return common_inputs
+
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ # Copied from OnnxConfig.generate_dummy_inputs
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+ batch_size = compute_effective_axis_dimension(
+ batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+ )
+
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
+ seq_length = compute_effective_axis_dimension(
+ seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+ )
+
+ # Generate dummy inputs according to compute batch and sequence
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
+ return common_inputs
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ if self.task in ["default", "seq2seq-lm"]:
+ common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+ elif self.task == "causal-lm":
+ common_inputs = self._generate_dummy_inputs_for_causal_lm(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+ else:
+ common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+ return common_inputs
+
+ def _flatten_past_key_values_(self, flattened_output, name, idx, t):
+ if self.task in ["default", "seq2seq-lm"]:
+ flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
+ else:
+ flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
+ flattened_output, name, idx, t
+ )
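As a hedged sketch (values are illustrative, not the ones shipped with UniMERNet's checkpoints), the `MBartConfig` defined above can be instantiated directly for a decoder-oriented setup; note that `attribute_map` exposes `d_model` as `hidden_size`, and extra keyword arguments such as `is_decoder` fall through to `PretrainedConfig`.

```python
from unimernet.models.unimernet.configuration_unimernet_decoder import MBartConfig

decoder_cfg = MBartConfig(
    vocab_size=50000,              # illustrative; match the formula tokenizer
    d_model=1024,
    decoder_layers=8,
    decoder_attention_heads=16,
    max_position_embeddings=1536,
    scale_embedding=True,
    is_decoder=True,               # forwarded to PretrainedConfig via **kwargs
    add_cross_attention=True,
)

print(decoder_cfg.hidden_size)     # 1024, aliased to d_model by attribute_map
print(decoder_cfg.model_type)      # "mbart"
```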
diff --git a/unimernet/models/unimernet/configuration_unimernet_encoder.py b/unimernet/models/unimernet/configuration_unimernet_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7bd531d9d4a9b4fd323dee58ec5462d48c17168
--- /dev/null
+++ b/unimernet/models/unimernet/configuration_unimernet_encoder.py
@@ -0,0 +1,132 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Donut Swin Transformer model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class UnimerNetConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`UnimerNetModel`]. It is used to instantiate a
+ Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Donut
+ [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 4):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ embed_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of patch embedding.
+ depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
+ Depth of each layer in the Transformer encoder.
+ num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`):
+ Number of attention heads in each layer of the Transformer encoder.
+ window_size (`int`, *optional*, defaults to 7):
+ Size of windows.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Ratio of MLP hidden dimensionality to embedding dimensionality.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not a learnable bias should be added to the queries, keys and values.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings and encoder.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ Stochastic depth rate.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to add absolute position embeddings to the patch embeddings.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import UnimerNetConfig, UnimerNetModel
+
+ >>> # Initializing a Donut naver-clova-ix/donut-base style configuration
+ >>> configuration = UnimerNetConfig()
+
+ >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration
+ >>> model = UnimerNetModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "donut-swin"
+
+ attribute_map = {
+ "num_attention_heads": "num_heads",
+ "num_hidden_layers": "num_layers",
+ }
+
+ def __init__(
+ self,
+ image_size=224,
+ patch_size=4,
+ num_channels=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_layers = len(depths)
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
+ # this indicates the channel dimension after the last stage of the model
+ self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
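A short sketch of the encoder config above, using the class defaults rather than UniMERNet's production values: the detail worth noting is that `hidden_size` is derived as `embed_dim * 2 ** (len(depths) - 1)`, the channel width after the last Swin stage, which is what `VisionEncoderDecoderModel` reads when wiring the encoder to the decoder.

```python
from unimernet.models.unimernet.configuration_unimernet_encoder import UnimerNetConfig

cfg = UnimerNetConfig(image_size=224, embed_dim=96, depths=[2, 2, 6, 2],
                      num_heads=[3, 6, 12, 24], window_size=7)

assert cfg.num_layers == 4                  # one entry per Swin stage
assert cfg.hidden_size == 96 * 2 ** 3       # 768 after the final stage
print(cfg.model_type)                       # "donut-swin"
```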
diff --git a/unimernet/models/unimernet/encoder_decoder.py b/unimernet/models/unimernet/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cf350afdea954ceeb513474b4d3bc861487682d
--- /dev/null
+++ b/unimernet/models/unimernet/encoder_decoder.py
@@ -0,0 +1,843 @@
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ftfy import fix_text
+from torch.nn import CrossEntropyLoss
+from typing import Optional, Tuple, Union, List
+from dataclasses import dataclass
+import math
+
+from transformers import PreTrainedTokenizerFast
+from transformers import VisionEncoderDecoderConfig
+from transformers import AutoModel, VisionEncoderDecoderModel, AutoImageProcessor, MBartForCausalLM
+from unimernet.models.unimernet.processor import VariableDonutProcessor, VariableDonutImageProcessor
+# from transformers.models.mbart.modeling_mbart import MBartDecoder
+from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
+# from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, DonutSwinEncoder
+from transformers.utils import logging, ModelOutput
+
+from functools import partial
+from .configuration_unimernet_encoder import UnimerNetConfig
+
+from .modeling_unimernet_encoder import UnimerNetPatchEmbeddings, UnimerNetEmbeddings, UnimerNetModel, UnimerNetEncoder
+from .modeling_unimernet_decoder import MBartDecoder
+
+logger = logging.get_logger(__name__)
+
+
+class VariableUnimerNetConfig(UnimerNetConfig):
+ pass
+
+
+def build_norm_layer(dim, norm_layer):
+ layers = []
+ if norm_layer == 'BN':
+ layers.append(nn.BatchNorm2d(dim))
+ else:
+ raise NotImplementedError(
+ f'build_norm_layer does not support {norm_layer}')
+ return nn.Sequential(*layers)
+
+class StemLayer(nn.Module):
+ r""" Stem layer of InternImage
+ Args:
+ in_chans (int): number of input channels
+ out_chans (int): number of output channels
+ act_layer (str): activation layer
+ norm_layer (str): normalization layer
+ """
+
+ def __init__(self,
+ in_chans=3,
+ out_chans=96,
+ act_layer=nn.GELU,
+ norm_layer='BN'):
+ super().__init__()
+ self.conv1 = nn.Conv2d(in_chans,
+ out_chans // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ self.norm1 = build_norm_layer(out_chans // 2, norm_layer)
+
+ self.act = act_layer()
+ self.conv2 = nn.Conv2d(out_chans // 2,
+ out_chans,
+ kernel_size=3,
+ stride=2,
+ padding=1)
+
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.act(x)
+ x = self.conv2(x)
+ return x
+
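A shape check for `StemLayer` (a sketch, not part of the module): its two stride-2 3x3 convolutions downsample height and width by 4 overall, which is why the stem can stand in for a `patch_size=4` patch-embedding projection.

```python
import torch

stem = StemLayer(in_chans=3, out_chans=96)
x = torch.randn(1, 3, 192, 672)    # illustrative formula-image resolution
y = stem(x)
print(y.shape)                     # torch.Size([1, 96, 48, 168]) -> H/4, W/4
```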
+
+class VariableUnimerNetPatchEmbeddings(UnimerNetPatchEmbeddings):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ print("VariableUnimerNetPatchEmbeddings init")
+ super().__init__(config)
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ self.projection = StemLayer(in_chans=num_channels, out_chans=hidden_size)
+
+
+class VariableUnimerNetEmbeddings(UnimerNetEmbeddings):
+ """
+ Construct the patch and position embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config, use_mask_token=False):
+ super().__init__(config, use_mask_token)
+
+ self.patch_embeddings = VariableUnimerNetPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.patch_grid = self.patch_embeddings.grid_size
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
+ self.position_embeddings = None
+
+ if config.use_absolute_embeddings:
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+
+ self.row_embeddings = None
+ self.column_embeddings = None
+ if config.use_2d_embeddings:
+ self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
+ self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))
+
+ self.norm = nn.LayerNorm(config.embed_dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, interpolate_pos_encoding: bool = False,
+ ) -> Tuple[torch.Tensor]:
+ # print('before pixel_values.shape',pixel_values.shape)
+
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+
+ # print('after embeddings.shape',embeddings.shape)
+
+ # Layernorm across the last dimension (each patch is a single row)
+ embeddings = self.norm(embeddings)
+ batch_size, seq_len, embed_dim = embeddings.size()
+
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.position_embeddings[:, :seq_len, :]
+
+ if self.row_embeddings is not None and self.column_embeddings is not None:
+ # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ...
+ row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1],
+ dim=1)
+ column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1)
+
+ embeddings = embeddings + row_embeddings + column_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings, output_dimensions
+
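To make the 2-D position-embedding broadcast above concrete, here is a tiny standalone sketch: on a row-major flattened H x W grid, `repeat_interleave` spreads each row embedding across a full row while `repeat` tiles the column embeddings, so patch (i, j) ends up with `row[i] + column[j]`.

```python
import torch

H, W, D = 2, 3, 4                                                    # tiny grid, tiny embed dim
rows = torch.arange(H).view(1, H, 1).expand(1, H, D).float()         # value i per row
cols = torch.arange(W).view(1, W, 1).expand(1, W, D).float() * 10    # value 10*j per column

flat = rows.repeat_interleave(W, dim=1) + cols.repeat(1, H, 1)       # (1, H*W, D)
print(flat[0, :, 0])   # tensor([ 0., 10., 20.,  1., 11., 21.]) = row index + 10 * column index
```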
+class VariableUnimerNetModel(UnimerNetModel):
+ config_class = VariableUnimerNetConfig
+
+ def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+ print("VariableUnimerNetModel init")
+ super().__init__(config)
+
+ self.config = config
+ self.num_layers = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.embeddings = VariableUnimerNetEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = UnimerNetEncoder(config, self.embeddings.patch_grid)
+
+ self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+ """
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ counting: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class CustomMBartDecoder(MBartDecoder):
+ def __init__(self, config):
+ print("CustomMBartDecoder init")
+ super().__init__(config)
+ hidden_size = config.d_model
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ count_pred: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
+ embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self._use_flash_attention_2:
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ # embed positions
+ positions = self.embed_positions(input, past_key_values_length)
+
+ hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+ # TODO: add counting context weight to hidden_states
+ if count_pred is not None:
+ count_context_weight = self.counting_context_weight(count_pred)
+ hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {attn_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class CustomMBartForCausalLM(MBartForCausalLM):
+ def __init__(self, config):
+ print("CustomMBartForCausalLM init")
+ super().__init__(config)
+ # Modify the decoder within MBartDecoderWrapper
+ self.model.decoder = CustomMBartDecoder(config)
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ count_gt: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ if the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+ in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MBartForCausalLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+ >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25", add_cross_attention=False)
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+ >>> list(logits.shape) == expected_shape
+ True
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ count_pred = None
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ count_pred=count_pred,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.lm_head(outputs[0])
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits, count_pred) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentionsAndCounting(
+ loss=loss,
+ logits=logits,
+ counting=count_pred,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+class CustomVisionEncoderDecoderModel(VisionEncoderDecoderModel):
+ def __init__(self, config):
+ print("CustomVisionEncoderDecoderModel init")
+ super().__init__(config)
+ # Replace the default encoder and decoder with the UnimerNet encoder and the custom MBart decoder
+ self.encoder = VariableUnimerNetModel(config.encoder)
+ self.decoder = CustomMBartForCausalLM(self.config.decoder)
+
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, VisionEncoderDecoderModel
+ >>> import requests
+ >>> from PIL import Image
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+ >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+
+ >>> # load image from the IAM dataset
+ >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+ >>> # training
+ >>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+ >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
+ >>> model.config.vocab_size = model.config.decoder.vocab_size
+
+ >>> pixel_values = processor(image, return_tensors="pt").pixel_values
+ >>> text = "hello world"
+ >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
+ >>> outputs = model(pixel_values=pixel_values, labels=labels)
+ >>> loss = outputs.loss
+
+ >>> # inference (generation)
+ >>> generated_ids = model.generate(pixel_values)
+ >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ if encoder_outputs is None:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ encoder_outputs = self.encoder(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs_encoder,
+ )
+ elif isinstance(encoder_outputs, tuple):
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ # the vision encoder produces no attention mask, so none is passed to the decoder
+ encoder_attention_mask = None
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ return_dict=return_dict,
+ **kwargs_decoder,
+ )
+
+ # Compute loss independent from decoder (as some shift the logits inside them)
+ loss = None
+ if labels is not None:
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+ loss_fct = CrossEntropyLoss()
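+            # Flatten logits to (batch * seq_len, vocab_size) and labels to (batch * seq_len,);
+            # positions labelled -100 (e.g. padding) are ignored by CrossEntropyLoss's default ignore_index.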
+ loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+            count_gt = kwargs_decoder.get("count_gt", None)  # currently unused in the loss computation
+
+ if not return_dict:
+ if loss is not None:
+ return (loss,) + decoder_outputs + encoder_outputs
+ else:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+class SelfAttentionBlock(nn.Module):
+ def __init__(self, embed_size, num_heads):
+        super().__init__()
+ self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)
+ self.norm = nn.LayerNorm(embed_size)
+
+ def forward(self, x):
+ attn_output, _ = self.self_attention(x, x, x)
+ x = self.norm(attn_output + x)
+ return x
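+
+# Usage sketch (illustrative only): nn.MultiheadAttention defaults to batch_first=False, so
+# SelfAttentionBlock expects inputs shaped (seq_len, batch, embed_size), e.g.
+#   block = SelfAttentionBlock(embed_size=256, num_heads=8)
+#   out = block(torch.randn(10, 2, 256))  # residual + LayerNorm, same shape as the input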
+
+
+class DonutEncoderDecoder(nn.Module):
+
+ def __init__(self, model_name, num_tokens, pad_token_id, bos_token_id, eos_token_id):
+ super().__init__()
+ config = VisionEncoderDecoderConfig.from_pretrained(model_name)
+ encoder_config = vars(config.encoder)
+ encoder = VariableUnimerNetConfig(**encoder_config)
+ config.encoder = encoder
+ self.config = config
+
+ AutoModel.register(VariableUnimerNetConfig, VariableUnimerNetModel)
+
+ self.model = CustomVisionEncoderDecoderModel(config=self.config)
+
+ self.model.config.decoder_start_token_id = bos_token_id
+ self.model.config.pad_token_id = pad_token_id
+ self.model.config.eos_token_id = eos_token_id
+ self.model.decoder.resize_token_embeddings(num_tokens)
+ self.pad_token_id = pad_token_id
+
+ def forward(self, pixel_values, decoder_input_ids, decoder_attention_mask, **kwargs):
+ num_channels = pixel_values.shape[1]
+ if num_channels == 1:
+ pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        labels = decoder_input_ids.clone()
+        labels = labels.masked_fill(labels == self.pad_token_id, -100)
+
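+        # Standard teacher forcing: the decoder consumes tokens [:-1] and is supervised with the
+        # shifted labels [1:]; pad positions were replaced by -100 above so the loss ignores them.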
+ loss = self.model(
+ pixel_values=pixel_values,
+ decoder_input_ids=decoder_input_ids[:, :-1],
+ decoder_attention_mask=decoder_attention_mask[:, :-1],
+ labels=labels[:, 1:],
+ **kwargs
+ ).loss
+ return loss
+
+ @torch.no_grad()
+ def generate(self, pixel_values, temperature, max_new_tokens, decoder_start_token_id, do_sample, top_p,
+ **kwargs):
+
+ num_channels = pixel_values.shape[1]
+ if num_channels == 1:
+ pixel_values = pixel_values.repeat(1, 3, 1, 1)
+ outputs = self.model.generate(
+ pixel_values=pixel_values,
+ max_new_tokens=max_new_tokens,
+ decoder_start_token_id=decoder_start_token_id,
+ temperature=temperature,
+ do_sample=do_sample,
+ top_p=top_p,
+ )
+ return outputs[:, 1:]
+
+
+class DonutTokenizer:
+ def __init__(self, path):
+ AutoImageProcessor.register(VariableUnimerNetConfig, VariableDonutImageProcessor)
+ processor = VariableDonutProcessor.from_pretrained(path)
+ processor.train = False
+ self.tokenizer = processor.tokenizer
+ self.max_seq_len = 2048
+ self.pad_token_id = self.tokenizer.pad_token_id
+ self.bos_token_id = self.tokenizer.bos_token_id
+ self.eos_token_id = self.tokenizer.eos_token_id
+
+ def __len__(self):
+ return len(self.tokenizer)
+
+ def tokenize(self, texts, max_length=None):
+ if not max_length:
+ max_length = self.max_seq_len
+ text_inputs = self.tokenizer(
+ texts,
+ return_token_type_ids=False,
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=max_length,
+ )
+ return text_inputs
+
+ @staticmethod
+ def post_process(text):
+ text = fix_text(text)
+ return text
+
+ def token2str(self, tokens) -> list:
+ generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
+ generated_text = [self.post_process(text) for text in generated_text]
+ return generated_text
+
+ def detokenize(self, tokens):
+ toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
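+        # Iterate in reverse so deleting special tokens does not shift the remaining indices;
+        # 'Ġ' is the byte-level BPE marker for a leading space and is normalised away here.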
+ for b in range(len(toks)):
+ for i in reversed(range(len(toks[b]))):
+ if toks[b][i] is None:
+ toks[b][i] = ''
+ toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
+ if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
+ del toks[b][i]
+ return toks
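+
+
+# Minimal usage sketch (illustrative; the processor path is an assumption):
+#   tokenizer = DonutTokenizer("./models/unimernet_tiny")
+#   batch = tokenizer.tokenize([r"\frac{a}{b}"])     # dict with input_ids / attention_mask
+#   texts = tokenizer.token2str(batch["input_ids"])  # decode ids back to cleaned strings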
diff --git a/unimernet/models/unimernet/modeling_unimernet_decoder.py b/unimernet/models/unimernet/modeling_unimernet_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42d46175cb8a228656cc3421a55029f9bd19575
--- /dev/null
+++ b/unimernet/models/unimernet/modeling_unimernet_decoder.py
@@ -0,0 +1,2158 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MBART model."""
+
+import copy
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ Seq2SeqQuestionAnsweringModelOutput,
+ Seq2SeqSequenceClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_end_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_unimernet_decoder import MBartConfig
+
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
+_CONFIG_FOR_DOC = "MBartConfig"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
+ """
+    Shift input ids one token to the right, and wrap the last non-pad token (the `<LID>` token). Note that MBart does
+    not have a single `decoder_start_token_id`, in contrast to other Bart-like models.
+ """
+ prev_output_tokens = input_ids.clone()
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
+
+ index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
+ decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
+ prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
+ prev_output_tokens[:, 0] = decoder_start_tokens
+
+ return prev_output_tokens
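+
+# Worked example (illustrative), with pad_token_id = 1:
+#   [[5, 6, 7, 2, 1]] -> [[2, 5, 6, 7, 2]]
+# i.e. the last non-pad token (here 2) is wrapped around to become the first decoder input.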
+
+
+# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
+class MBartLearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+ """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+ bsz, seq_len = input_ids.shape[:2]
+ positions = torch.arange(
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+ ).expand(bsz, -1)
+
+ return super().forward(positions + self.offset)
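+
+# Example (illustrative): with past_key_values_length=3 and seq_len=2, positions are [3, 4],
+# and rows [5, 6] of the embedding table are looked up because of the fixed offset of 2.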
+
+
+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
+class MBartScaledWordEmbedding(nn.Embedding):
+ """
+    This module overrides nn.Embedding's forward by multiplying with the embedding scale.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.embed_scale = embed_scale
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
+class MBartSqueezeAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper, with qk_squeeze"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ qk_squeeze: int = 2,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[MBartConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+
+ self.squeeze_dim = embed_dim // qk_squeeze
+ self.squeeze_head_dim = self.squeeze_dim // num_heads
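+        # Queries and keys are projected down to squeeze_dim = embed_dim // qk_squeeze, so the
+        # QK^T product is computed in a reduced dimension, while the values and the output
+        # projection keep the full embed_dim.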
+ self.scaling = self.squeeze_head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, self.squeeze_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, self.squeeze_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()
+
+ def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
+ value_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.reshape(*proj_shape)
+ value_states = value_states.reshape(*value_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart
+class MBartFlashAttention2(MBartSqueezeAttention):
+ """
+    MBart flash attention module. This module inherits from `MBartSqueezeAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # MBartFlashAttention2 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("MBartFlashAttention2 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability,
+        # so the input hidden states may be silently upcast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+MBART_ATTENTION_CLASSES = {
+ "eager": MBartSqueezeAttention,
+ "flash_attention_2": MBartFlashAttention2,
+}
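+
+# Note (illustrative): the attention implementation is selected per layer from
+# config._attn_implementation, which is normally set via the `attn_implementation`
+# argument when loading or instantiating the model; "eager" is the fallback and
+# "flash_attention_2" requires the flash-attn package.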
+
+
+class MBartEncoderLayer(nn.Module):
+ def __init__(self, config: MBartConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ config=config,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_head_mask: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
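+
+# Note: MBart layers use a pre-norm layout: LayerNorm is applied before self-attention and
+# before the feed-forward block, and a final layer_norm is applied on top of the whole stack.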
+
+
+class MBartDecoderLayer(nn.Module):
+ def __init__(self, config: MBartConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ config=config,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MBart
+class MBartClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(
+ self,
+ input_dim: int,
+ inner_dim: int,
+ num_classes: int,
+ pooler_dropout: float,
+ ):
+ super().__init__()
+ self.dense = nn.Linear(input_dim, inner_dim)
+ self.dropout = nn.Dropout(p=pooler_dropout)
+ self.out_proj = nn.Linear(inner_dim, num_classes)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.dense(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.out_proj(hidden_states)
+ return hidden_states
+
+
+class MBartPreTrainedModel(PreTrainedModel):
+ config_class = MBartConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MBartDecoderLayer", "MBartSqueezeAttention"]
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+MBART_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`MBartConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MBART_GENERATION_EXAMPLE = r"""
+ Translation example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
+
+ >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
+
+ >>> example_english_phrase = "42 is the answer"
+ >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
+
+ >>> # Translate
+ >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
+ >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ '42 este răspuns'
+ ```
+
+ Mask filling example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
+
+ >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+
+ >>> # de_DE is the language symbol id for German
+ >>> TXT = " Meine Freunde sind nett aber sie essen zu viel Kuchen. de_DE"
+
+ >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
+ >>> logits = model(input_ids).logits
+
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ >>> probs = logits[0, masked_index].softmax(dim=0)
+ >>> values, predictions = probs.topk(5)
+
+ >>> tokenizer.decode(predictions).split()
+ ['nett', 'sehr', 'ganz', 'nicht', 'so']
+ ```
+"""
+
+MBART_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
+ varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+ for denoising pre-training following the paper.
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class MBartEncoder(MBartPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`MBartEncoderLayer`].
+
+ Args:
+ config: MBartConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ self.embed_tokens = MBartScaledWordEmbedding(
+ config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+ )
+
+ if embed_tokens is not None:
+ self.embed_tokens.weight = embed_tokens.weight
+
+ self.embed_positions = MBartLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ embed_dim,
+ )
+ self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _backward_compatibility_gradient_checkpointing(self):
+ # Override to not delete the attribute from the config
+ if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
+ self.gradient_checkpointing_enable()
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ embed_pos = self.embed_positions(input)
+
+ hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ if self._use_flash_attention_2:
+ attention_mask = attention_mask if 0 in attention_mask else None
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if to_drop:
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class MBartDecoder(MBartPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
+
+ Args:
+ config: MBartConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_target_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+ self.embed_tokens = MBartScaledWordEmbedding(
+ config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+ )
+
+ if embed_tokens is not None:
+ self.embed_tokens.weight = embed_tokens.weight
+
+ self.embed_positions = MBartLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ )
+ self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self._use_flash_attention_2:
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ # embed positions
+ positions = self.embed_positions(input, past_key_values_length)
+
+ hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+ hidden_states = self.layernorm_embedding(hidden_states)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {attn_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
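+# Note on the decoder forward above (illustrative sketch, not part of the vendored code): with `use_cache=True`
+# the returned `past_key_values` can be fed back so that each follow-up call only needs the newest token.
+# `decoder`, `ids`, `enc` and `next_token` below are placeholder names:
+#   out = decoder(input_ids=ids, encoder_hidden_states=enc, use_cache=True)
+#   out = decoder(input_ids=next_token, past_key_values=out.past_key_values,
+#                 encoder_hidden_states=enc, use_cache=True)   # next_token has shape (batch_size, 1)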
+
+@add_start_docstrings(
+ "The bare MBART Model outputting raw hidden-states without any specific head on top.",
+ MBART_START_DOCSTRING,
+)
+class MBartModel(MBartPreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config: MBartConfig):
+ super().__init__(config)
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+ self.encoder = MBartEncoder(config, self.shared)
+ self.decoder = MBartDecoder(config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ def _tie_weights(self):
+ if self.config.tie_word_embeddings:
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+
+ @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Seq2SeqModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Seq2SeqModelOutput, Tuple[torch.FloatTensor]]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # unlike other models, MBart automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
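+# Note on the forward pass above (illustrative, not part of the vendored code): when no decoder inputs are
+# given, `shift_tokens_right` builds `decoder_input_ids` from `input_ids` by wrapping the last non-pad token
+# (the language id in MBart) to the front, roughly:
+#   [tok_1, tok_2, </s>, lang_id, <pad>]  ->  [lang_id, tok_1, tok_2, </s>, <pad>]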
+
+@add_start_docstrings(
+ "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.",
+ MBART_START_DOCSTRING,
+)
+class MBartForConditionalGeneration(MBartPreTrainedModel):
+ base_model_prefix = "model"
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+ _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
+
+ def __init__(self, config: MBartConfig):
+ super().__init__(config)
+ self.model = MBartModel(config)
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+ return new_embeddings
+
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+ old_num_tokens = self.final_logits_bias.shape[-1]
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ @add_end_docstrings(MBART_GENERATION_EXAMPLE)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if use_cache:
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past is used
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
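+        # Illustrative example (placeholder values): if the cache already covers past_length == 5 tokens and
+        # `decoder_input_ids` has 6 columns, only the newest column `decoder_input_ids[:, 5:]` is kept, so the
+        # decoder re-processes just the most recent token on every generation step.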
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id)
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+ + layer_past[2:],
+ )
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+    MBart model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
+ tasks.
+ """,
+ MBART_START_DOCSTRING,
+)
+class MBartForSequenceClassification(MBartPreTrainedModel):
+ _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
+ def __init__(self, config: MBartConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.model = MBartModel(config)
+ self.classification_head = MBartClassificationHead(
+ config.d_model,
+ config.d_model,
+ config.num_labels,
+ config.classifier_dropout,
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Seq2SeqSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ if input_ids is None and inputs_embeds is not None:
+ raise NotImplementedError(
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0] # last hidden state
+
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
+
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+            raise ValueError("All examples must have the same number of <eos> tokens.")
+ sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+ :, -1, :
+ ]
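+        # Descriptive note: the pooled sentence representation is the decoder hidden state at the final <eos>
+        # token of each sequence, i.e. with hidden_states of shape (batch_size, seq_len, d_model),
+        # `sentence_representation` has shape (batch_size, d_model).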
+ logits = self.classification_head(sentence_representation)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.config.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.config.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ MBART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ MBART_START_DOCSTRING,
+)
+class MBartForQuestionAnswering(MBartPreTrainedModel):
+ _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ config.num_labels = 2
+ self.num_labels = config.num_labels
+
+ self.model = MBartModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Seq2SeqQuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
+ def forward(
+ self,
+ input_ids: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if start_positions is not None and end_positions is not None:
+ use_cache = False
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (
+ start_logits,
+ end_logits,
+ ) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return Seq2SeqQuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->MBart
+class MBartDecoderWrapper(MBartPreTrainedModel):
+ """
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+ used in combination with the [`EncoderDecoderModel`] framework.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.decoder = MBartDecoder(config)
+
+ def forward(self, *args, **kwargs):
+ return self.decoder(*args, **kwargs)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
+class MBartForCausalLM(MBartPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ config = copy.deepcopy(config)
+ config.is_decoder = True
+ config.is_encoder_decoder = False
+ super().__init__(config)
+ self.model = MBartDecoderWrapper(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ if the model is configured as a decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MBartForCausalLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+ >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25", add_cross_attention=False)
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+ >>> list(logits.shape) == expected_shape
+ True
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.lm_head(outputs[0])
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+ ):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ if past_key_values:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
+ # first step, decoder_cached_states are empty
+ return {
+            "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
diff --git a/unimernet/models/unimernet/modeling_unimernet_encoder.py b/unimernet/models/unimernet/modeling_unimernet_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a8653071bdcb7e7460c1ec27a268b28409b2fa7
--- /dev/null
+++ b/unimernet/models/unimernet/modeling_unimernet_encoder.py
@@ -0,0 +1,1035 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch UnimerNet Transformer model.
+
+This implementation is identical to a regular Swin Transformer, except that it applies no final layer norm on top of
+the final hidden states."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from transformers.utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ torch_int,
+)
+from .configuration_unimernet_encoder import UnimerNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "UnimerNetConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->UnimerNet
+class UnimerNetEncoderOutput(ModelOutput):
+ """
+ UnimerNet encoder's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->UnimerNet
+class UnimerNetModelOutput(ModelOutput):
+ """
+ UnimerNet model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+ Average pooling of the last layer hidden-state.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+# Copied from transformers.models.swin.modeling_swin.window_partition
+def window_partition(input_feature, window_size):
+ """
+ Partitions the given input into windows.
+ """
+ batch_size, height, width, num_channels = input_feature.shape
+ input_feature = input_feature.view(
+ batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+ )
+ windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+ return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.window_reverse
+def window_reverse(windows, window_size, height, width):
+ """
+ Merges windows to produce higher resolution features.
+ """
+ num_channels = windows.shape[-1]
+ windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
+ windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
+ return windows
+
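+# Illustrative shapes for the two helpers above (assumed values, not from the original code): for a feature map
+# of shape (2, 56, 56, 96) and window_size=7,
+#   windows = window_partition(x, 7)               # -> (2 * 8 * 8, 7, 7, 96)
+#   restored = window_reverse(windows, 7, 56, 56)  # -> (2, 56, 56, 96), inverting window_partition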
+
+# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->UnimerNet
+class UnimerNetEmbeddings(nn.Module):
+ """
+ Construct the patch and position embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config, use_mask_token=False):
+        super().__init__()
+        # keep a reference to the config: `interpolate_pos_encoding` reads `patch_size` from it
+        self.config = config
+
+        self.patch_embeddings = UnimerNetPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.patch_grid = self.patch_embeddings.grid_size
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
+
+ if config.use_absolute_embeddings:
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+ else:
+ self.position_embeddings = None
+
+ self.norm = nn.LayerNorm(config.embed_dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+        This method interpolates the pre-trained position encodings so that the model can be used on
+        higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor],
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> Tuple[torch.Tensor]:
+ _, num_channels, height, width = pixel_values.shape
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ embeddings = self.norm(embeddings)
+ batch_size, seq_len, _ = embeddings.size()
+
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ if self.position_embeddings is not None:
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings, output_dimensions
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->UnimerNet
+class UnimerNetPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+ self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def maybe_pad(self, pixel_values, height, width):
+ if width % self.patch_size[1] != 0:
+ pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ if height % self.patch_size[0] != 0:
+ pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ return pixel_values
+
+ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+ _, num_channels, height, width = pixel_values.shape
+ # pad the input to be divisible by self.patch_size, if needed
+ pixel_values = self.maybe_pad(pixel_values, height, width)
+ embeddings = self.projection(pixel_values)
+ _, _, height, width = embeddings.shape
+ output_dimensions = (height, width)
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+
+ return embeddings, output_dimensions
+
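+# Illustrative shapes for the patch embedding above (assumed Swin-like values: 224x224 input, patch_size=4,
+# 3 channels, embed_dim=96):
+#   pixel_values: (1, 3, 224, 224)
+#   embeddings, output_dimensions = patch_embeddings(pixel_values)
+#   embeddings: (1, 56 * 56, 96), output_dimensions: (56, 56)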
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
+class UnimerNetPatchMerging(nn.Module):
+ """
+ Patch Merging Layer.
+
+ Args:
+ input_resolution (`Tuple[int]`):
+ Resolution of input feature.
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+ Normalization layer class.
+ """
+
+ def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def maybe_pad(self, input_feature, height, width):
+ should_pad = (height % 2 == 1) or (width % 2 == 1)
+ if should_pad:
+ pad_values = (0, 0, 0, width % 2, 0, height % 2)
+ input_feature = nn.functional.pad(input_feature, pad_values)
+
+ return input_feature
+
+ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
+ height, width = input_dimensions
+ # `dim` is height * width
+ batch_size, dim, num_channels = input_feature.shape
+
+ input_feature = input_feature.view(batch_size, height, width, num_channels)
+        # pad the input so that height and width are divisible by 2, if needed
+ input_feature = self.maybe_pad(input_feature, height, width)
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_0 = input_feature[:, 0::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_1 = input_feature[:, 1::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_2 = input_feature[:, 0::2, 1::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_3 = input_feature[:, 1::2, 1::2, :]
+ # batch_size height/2 width/2 4*num_channels
+ input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+ input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
+
+ input_feature = self.norm(input_feature)
+ input_feature = self.reduction(input_feature)
+
+ return input_feature
+
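+# Illustrative shapes for the patch merging above (assumed values): merging halves the spatial resolution and
+# doubles the channel count, e.g. input_feature of shape (1, 56 * 56, 96) with input_dimensions (56, 56)
+# becomes (1, 28 * 28, 192).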
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
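+# Descriptive note on drop_path above: during training with e.g. drop_prob=0.1, roughly 10% of the samples in
+# the batch have this residual branch zeroed out entirely, and the surviving samples are rescaled by 1 / 0.9 so
+# the expected value of the output is unchanged.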
+
+# Copied from transformers.models.swin.modeling_swin.SwinDropPath
+class UnimerNetDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->UnimerNet
+class UnimerNetSelfAttention(nn.Module):
+ def __init__(self, config, dim, num_heads, window_size):
+ super().__init__()
+ if dim % num_heads != 0:
+ raise ValueError(
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+ )
+
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(dim / num_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.window_size = (
+ window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+ )
+
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+ coords_flatten = torch.flatten(coords, 1)
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+ relative_coords[:, :, 0] += self.window_size[0] - 1
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1)
+ self.register_buffer("relative_position_index", relative_position_index)
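+        # Descriptive note: `relative_position_index[i, j]` maps the relative offset between window positions i
+        # and j to a row of `relative_position_bias_table`, so the learned attention bias depends only on the
+        # relative (dy, dx) between the two tokens, not on their absolute positions.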
+
+ self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ batch_size, dim, num_channels = hidden_states.shape
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
+ relative_position_bias = relative_position_bias.view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ )
+
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+ attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
+
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the UnimerNetModel forward() function)
+ mask_shape = attention_mask.shape[0]
+ attention_scores = attention_scores.view(
+ batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+ )
+ attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
+ attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
+class UnimerNetSelfOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->UnimerNet
+class UnimerNetAttention(nn.Module):
+ def __init__(self, config, dim, num_heads, window_size):
+ super().__init__()
+ self.self = UnimerNetSelfAttention(config, dim, num_heads, window_size)
+ self.output = UnimerNetSelfOutput(config, dim)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
+class UnimerNetIntermediate(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinOutput
+class UnimerNetOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class ConvEnhance(nn.Module):
+ """Depth-wise convolution that injects positional information into the token sequence."""
+ def __init__(self, config, dim, k=3):
+ super(ConvEnhance, self).__init__()
+ self.proj = nn.Conv2d(dim, dim, kernel_size=(k, k), stride=(1, 1), padding=(k // 2, k // 2), groups=dim)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x, size: Tuple[int, int]):
+ B, N, C = x.shape
+ H, W = size
+ assert N == H * W
+
+ feat = x.transpose(1, 2).view(B, C, H, W)
+ feat = self.proj(feat)
+ feat = self.act_fn(feat)
+ feat = feat.flatten(2).transpose(1, 2)
+
+ x = x + feat
+ return x
+
+
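+ # Shape walk-through for ConvEnhance above (illustrative, sizes are assumptions):
+ #   x: (B, N, C) with N = H * W, e.g. B=1, H=W=7, C=96  ->  x is (1, 49, 96)
+ #   transpose + view       -> (1, 96, 7, 7)
+ #   depth-wise 3x3 conv    -> (1, 96, 7, 7)  (groups=C, so each channel is filtered on its own)
+ #   flatten + transpose    -> (1, 49, 96)
+ #   residual add           -> x + feat, same (1, 49, 96) shape as the input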
+
+# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->UnimerNet
+class UnimerNetLayer(nn.Module):
+ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.shift_size = shift_size
+ self.window_size = config.window_size
+ self.input_resolution = input_resolution
+ self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+
+ self.ce = nn.ModuleList([ConvEnhance(config, dim=dim, k=3),
+ ConvEnhance(config, dim=dim, k=3)])
+
+ self.attention = UnimerNetAttention(config, dim, num_heads, window_size=self.window_size)
+ self.drop_path = UnimerNetDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.intermediate = UnimerNetIntermediate(config, dim)
+ self.output = UnimerNetOutput(config, dim)
+
+ def set_shift_and_window_size(self, input_resolution):
+ if min(input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = torch_int(0)
+ self.window_size = (
+ torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+ )
+
+ def get_attn_mask(self, height, width, dtype, device):
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
+ height_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ width_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ count = 0
+ for height_slice in height_slices:
+ for width_slice in width_slices:
+ img_mask[:, height_slice, width_slice, :] = count
+ count += 1
+
+ mask_windows = window_partition(img_mask, self.window_size)
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+ return attn_mask
+
+ def maybe_pad(self, hidden_states, height, width):
+ pad_right = (self.window_size - width % self.window_size) % self.window_size
+ pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+ pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
+ hidden_states = nn.functional.pad(hidden_states, pad_values)
+ return hidden_states, pad_values
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if not always_partition:
+ self.set_shift_and_window_size(input_dimensions)
+ else:
+ pass
+ height, width = input_dimensions
+ batch_size, _, channels = hidden_states.size()
+
+
+
+ hidden_states = self.ce[0](hidden_states, input_dimensions)
+ shortcut = hidden_states
+
+
+ hidden_states = self.layernorm_before(hidden_states)
+ hidden_states = hidden_states.view(batch_size, height, width, channels)
+
+ # pad hidden_states to multiples of window size
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+ _, height_pad, width_pad, _ = hidden_states.shape
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_hidden_states = hidden_states
+
+ # partition windows
+ hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
+ hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
+ attn_mask = self.get_attn_mask(
+ height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+ )
+
+ attention_outputs = self.attention(
+ hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
+ )
+
+ attention_output = attention_outputs[0]
+
+ attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
+ shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ attention_windows = shifted_windows
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_windows = attention_windows[:, :height, :width, :].contiguous()
+
+ attention_windows = attention_windows.view(batch_size, height * width, channels)
+
+ hidden_states = shortcut + self.drop_path(attention_windows)
+
+
+
+ hidden_states = self.ce[1](hidden_states, input_dimensions)
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+ layer_output = hidden_states + self.output(layer_output)
+
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+ return layer_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->UnimerNet
+class UnimerNetStage(nn.Module):
+ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
+ super().__init__()
+ self.config = config
+ self.dim = dim
+ self.blocks = nn.ModuleList(
+ [
+ UnimerNetLayer(
+ config=config,
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ shift_size=0,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
+ else:
+ self.downsample = None
+
+ self.pointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ height, width = input_dimensions
+ for i, layer_module in enumerate(self.blocks):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
+
+ hidden_states = layer_outputs[0]
+
+ hidden_states_before_downsampling = hidden_states
+ if self.downsample is not None:
+ height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+ output_dimensions = (height, width, height_downsampled, width_downsampled)
+ hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
+ else:
+ output_dimensions = (height, width, height, width)
+
+ stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
+
+ if output_attentions:
+ stage_outputs += layer_outputs[1:]
+ return stage_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->UnimerNet
+class UnimerNetEncoder(nn.Module):
+ def __init__(self, config, grid_size):
+ super().__init__()
+ self.num_layers = len(config.depths)
+ self.config = config
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+ self.layers = nn.ModuleList(
+ [
+ UnimerNetStage(
+ config=config,
+ dim=int(config.embed_dim * 2**i_layer),
+ input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+ depth=config.depths[i_layer],
+ num_heads=config.num_heads[i_layer],
+ drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+ downsample=UnimerNetPatchMerging if (i_layer < self.num_layers - 1) else None,
+ )
+ for i_layer in range(self.num_layers)
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ output_hidden_states_before_downsampling: Optional[bool] = False,
+ always_partition: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, UnimerNetEncoderOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = hidden_states.shape
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.__call__,
+ hidden_states,
+ input_dimensions,
+ layer_head_mask,
+ output_attentions,
+ always_partition,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+ )
+
+ hidden_states = layer_outputs[0]
+ hidden_states_before_downsampling = layer_outputs[1]
+ output_dimensions = layer_outputs[2]
+
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+
+ if output_hidden_states and output_hidden_states_before_downsampling:
+ batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+ # rearrange b (h w) c -> b c h w
+ # here we use the original (not downsampled) height and width
+ reshaped_hidden_state = hidden_states_before_downsampling.view(
+ batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+ )
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states_before_downsampling,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+ elif output_hidden_states and not output_hidden_states_before_downsampling:
+ batch_size, _, hidden_size = hidden_states.shape
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ if output_attentions:
+ all_self_attentions += layer_outputs[3:]
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return UnimerNetEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->UnimerNet
+class UnimerNetPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = UnimerNetConfig
+ base_model_prefix = "swin"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["UnimerNetStage"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+SWIN_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`UnimerNetConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIN_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`DonutImageProcessor.__call__`] for details.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare UnimerNet Model transformer outputting raw hidden-states without any specific head on top.",
+ SWIN_START_DOCSTRING,
+)
+class UnimerNetModel(UnimerNetPreTrainedModel):
+ def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+ super().__init__(config)
+ self.config = config
+ self.num_layers = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.embeddings = UnimerNetEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = UnimerNetEncoder(config, self.embeddings.patch_grid)
+
+ self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=UnimerNetModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, UnimerNetModelOutput]:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ input_dimensions,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ pooled_output = None
+ if self.pooler is not None:
+ pooled_output = self.pooler(sequence_output.transpose(1, 2))
+ pooled_output = torch.flatten(pooled_output, 1)
+
+ if not return_dict:
+ output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return output
+
+ return UnimerNetModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
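+
+
+ # Minimal usage sketch (illustrative only; it assumes UnimerNetConfig exposes the usual
+ # Swin-style fields such as num_channels and image_size, which is not verified here):
+ #
+ #   config = UnimerNetConfig()
+ #   model = UnimerNetModel(config)
+ #   pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
+ #   outputs = model(pixel_values)
+ #   outputs.last_hidden_state   # (batch, sequence_length, model.num_features)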
diff --git a/unimernet/models/unimernet/processor.py b/unimernet/models/unimernet/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dce532e08cb6c671c500c33464a283e2a178afe
--- /dev/null
+++ b/unimernet/models/unimernet/processor.py
@@ -0,0 +1,192 @@
+from typing import Dict, Union, Optional, List
+
+from torch import TensorType
+from transformers import DonutImageProcessor, DonutProcessor
+from transformers.image_processing_utils import BatchFeature
+from transformers.image_transforms import pad
+from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
+ valid_images, to_numpy_array, is_scaled_image, get_image_size
+import numpy as np
+import PIL
+import logging
+
+logger = logging.getLogger()
+
+IMAGE_STD = [0.229, 0.224, 0.225]
+IMAGE_MEAN = [0.485, 0.456, 0.406]
+
+
+class VariableDonutImageProcessor(DonutImageProcessor):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def numpy_resize(self, image: np.ndarray, size, resample):
+ image = PIL.Image.fromarray(image)
+ resized = self.pil_resize(image, size, resample)
+ resized = np.array(resized, dtype=np.uint8)
+ resized_image = resized.transpose(2, 0, 1)
+
+ return resized_image
+
+ def pil_resize(self, image: PIL.Image.Image, size, resample):
+ width, height = image.size
+ max_width, max_height = size["width"], size["height"]
+ if width != max_width or height != max_height:
+ # Shrink to fit within dimensions
+ width_scale = max_width / width
+ height_scale = max_height / height
+ scale = min(width_scale, height_scale)
+
+ new_width = min(int(width * scale), max_width)
+ new_height = min(int(height * scale), max_height)
+
+ image = image.resize((new_width, new_height), resample)
+
+ image.thumbnail((max_width, max_height), resample)
+
+ assert image.width <= max_width and image.height <= max_height
+
+ return image
+
+ def process_inner(self, images: List[List], train=False):
+ # This will be in list of lists format, with height x width x channel
+ assert isinstance(images[0], (list, np.ndarray))
+
+ # convert list of lists format to array
+ if isinstance(images[0], list):
+ # numpy uint8 needed for augmentation
+ np_images = [np.array(img, dtype=np.uint8) for img in images]
+ else:
+ np_images = [img.astype(np.uint8) for img in images]
+
+ assert np_images[0].shape[2] == 3 # RGB input images, channel dim last
+
+ # This also applies the right channel dim format, to channel x height x width
+ np_images = [self.numpy_resize(img, self.max_size, self.resample) for img in np_images]
+ assert np_images[0].shape[0] == 3 # RGB input images, channel dim first
+
+ # Convert to float32 for rescale/normalize
+ np_images = [img.astype(np.float32) for img in np_images]
+
+ # Pads with 255 (whitespace)
+ # Pad to max size to improve performance
+ max_size = self.max_size
+ np_images = [
+ self.pad_image(
+ image=image,
+ size=max_size,
+ random_padding=train, # Change amount of padding randomly during training
+ input_data_format=ChannelDimension.FIRST,
+ pad_value=255.0
+ )
+ for image in np_images
+ ]
+
+ # Rescale and normalize
+ np_images = [
+ self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
+ for img in np_images
+ ]
+ np_images = [
+ self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
+ for img in np_images
+ ]
+
+ return np_images
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_thumbnail: bool = None,
+ do_align_long_axis: bool = None,
+ do_pad: bool = None,
+ random_padding: bool = False,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> PIL.Image.Image:
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ # Convert to numpy for later processing steps
+ images = [to_numpy_array(image) for image in images]
+
+ images = self.process_inner(images, train=False)
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def pad_image(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ random_padding: bool = False,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ pad_value: float = 0.0,
+ ) -> np.ndarray:
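+ # Worked example (illustrative numbers): an image resized to 100x300 and padded to a
+ # 192x672 canvas gives delta_height=92 and delta_width=372; without random_padding the
+ # content is centred (pad_top=46, pad_bottom=46, pad_left=186, pad_right=186) and the
+ # border is filled with pad_value (255.0, i.e. white, when called from process_inner).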
+ output_height, output_width = size["height"], size["width"]
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+
+ delta_width = output_width - input_width
+ delta_height = output_height - input_height
+
+ assert delta_width >= 0 and delta_height >= 0
+
+ if random_padding:
+ pad_top = np.random.randint(low=0, high=delta_height + 1)
+ pad_left = np.random.randint(low=0, high=delta_width + 1)
+ else:
+ pad_top = delta_height // 2
+ pad_left = delta_width // 2
+
+ pad_bottom = delta_height - pad_top
+ pad_right = delta_width - pad_left
+
+ padding = ((pad_top, pad_bottom), (pad_left, pad_right))
+ return pad(image, padding, data_format=data_format, input_data_format=input_data_format,
+ constant_values=pad_value)
+
+
+class VariableDonutProcessor(DonutProcessor):
+ def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
+ if image_processor is None:
+ raise ValueError("You need to specify an `image_processor`.")
+ if tokenizer is None:
+ raise ValueError("You need to specify a `tokenizer`.")
+
+ super().__init__(image_processor, tokenizer)
+ self.current_processor = self.image_processor
+ self._in_target_context_manager = False
+ self.train = train
+
+ def __call__(self, *args, **kwargs):
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ images = kwargs.pop("images", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ images = args[0]
+ args = args[1:]
+
+ if images is None:
+ raise ValueError("You need to specify images to process.")
+
+ inputs = self.image_processor(images, *args, **kwargs)
+ return inputs
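+
+
+ # Minimal usage sketch (illustrative; assumes the processor's size/mean/std settings are
+ # loaded from a saved config, e.g. with from_pretrained, which is not shown in this file):
+ #
+ #   processor = VariableDonutImageProcessor.from_pretrained("path/to/saved_processor")
+ #   image = PIL.Image.open("formula.png").convert("RGB")
+ #   pixel_values = processor.preprocess([image], return_tensors="pt")["pixel_values"]
+ #   # pixel_values: (1, 3, H, W), resized to fit, padded with white, rescaled and normalized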
diff --git a/unimernet/models/unimernet/unimernet.py b/unimernet/models/unimernet/unimernet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f224ae449aaf99b250b0507401c39a40ab233ddd
--- /dev/null
+++ b/unimernet/models/unimernet/unimernet.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn.functional as F
+from unimernet.common.registry import registry
+from unimernet.models.blip2_models.blip2 import Blip2Base
+from unimernet.models.unimernet.encoder_decoder import DonutEncoderDecoder, DonutTokenizer
+
+
+@registry.register_model("unimernet")
+class UniMERModel(Blip2Base):
+ """
+ UniMERNet model for formula recognition.
+ Supported model types:
+ - default
+ Usage:
+ >>> from unimernet.models import load_model
+ >>> model = load_model("unimernet", "default")
+ """
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "default": "configs/models/unimernet_base.yaml",
+ "unimernet": "configs/models/unimernet_base.yaml",
+ }
+
+ def __init__(
+ self,
+ *,
+ model_name,
+ model_config,
+ tokenizer_name,
+ tokenizer_config,
+ ):
+ super().__init__()
+
+ self.tokenizer = DonutTokenizer(tokenizer_config.path)
+ self.model = DonutEncoderDecoder(
+ model_config.model_name,
+ num_tokens=len(self.tokenizer),
+ bos_token_id=self.tokenizer.bos_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ )
+ self.max_seq_len = model_config.max_seq_len
+ self.tokenizer.max_seq_len = self.max_seq_len
+
+ def forward(self, samples):
+ image, text = samples["image"], samples["text_input"]
+
+ text_inputs = self.tokenizer.tokenize(text).to(image.device)
+ count_gt = self._get_count_gt(text, image.device)
+ tgt_seq, tgt_mask = text_inputs["input_ids"], text_inputs["attention_mask"]
+ with self.maybe_autocast():
+ loss = self.model(
+ pixel_values=image,
+ decoder_input_ids=tgt_seq,
+ decoder_attention_mask=tgt_mask,
+ decoder_count_gt=count_gt,
+ )
+ return {"loss": loss}
+
+ def _get_count_gt(self, text, device):
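+ # Worked example (illustrative, toy token ids): if one sample tokenizes to
+ # [5, 7, 5, <pad>], the masked one-hot sum below yields count_gt[0, 5] == 2 and
+ # count_gt[0, 7] == 1, with every other vocab entry (including <pad>) equal to 0,
+ # i.e. a per-sample histogram of symbol frequencies used as the counting target.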
+ labels = self.tokenizer.tokenize(text, max_length=1536)["input_ids"].to(device)
+ mask = labels != self.tokenizer.pad_token_id
+ one_hot_labels = F.one_hot(labels, num_classes=self.tokenizer.tokenizer.vocab_size) * mask.unsqueeze(-1)
+ count_gt = torch.sum(one_hot_labels, dim=1)
+ return count_gt # (bs, vocab_size)
+
+ @torch.no_grad()
+ def generate(
+ self,
+ samples,
+ temperature: float = 0.2,
+ do_sample: bool = False,
+ top_p: float = 0.95,
+ **kwargs
+ ):
+
+ image = samples["image"]
+ with self.maybe_autocast():
+ outputs = self.model.generate(
+ pixel_values=image,
+ temperature=temperature,
+ max_new_tokens=self.max_seq_len,
+ decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
+ # decoder_end_token_id=self.tokenizer.tokenizer.eos_token_id,
+ do_sample=do_sample,
+ top_p=top_p,
+ **kwargs
+ )
+ pred_tokens = self.tokenizer.detokenize(outputs)
+ pred_str = self.tokenizer.token2str(outputs)
+ return {"pred_tokens": pred_tokens, "pred_str": pred_str, "pred_ids": outputs}
+
+ @classmethod
+ def from_config(cls, cfg):
+
+ model_name = cfg.get("model_name")
+ model_config = cfg.get("model_config")
+ tokenizer_name = cfg.get("tokenizer_name")
+ tokenizer_config = cfg.get("tokenizer_config")
+
+ model = cls(
+ model_name=model_name,
+ model_config=model_config,
+ tokenizer_name=tokenizer_name,
+ tokenizer_config=tokenizer_config
+ )
+
+ model.load_checkpoint_from_config(cfg)
+
+ return model
diff --git a/unimernet/models/unimernet/utils.py b/unimernet/models/unimernet/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce00f4656c00777f868c7740a480987b2d4d8c40
--- /dev/null
+++ b/unimernet/models/unimernet/utils.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+
+from . import hybrid
+from . import vit
+from . import transformer
+
+
+class Model(nn.Module):
+ def __init__(self, encoder, decoder, args):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.args = args
+
+ def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
+ if not device_ids or len(device_ids) == 1:
+ return self(x, **kwargs)
+ if output_device is None:
+ output_device = device_ids[0]
+ replicas = nn.parallel.replicate(self, device_ids)
+ inputs = nn.parallel.scatter(x, device_ids) # Slices tensors into approximately equal chunks and distributes them across given GPUs.
+ kwargs = nn.parallel.scatter(kwargs, device_ids) # Duplicates references to objects that are not tensors.
+ replicas = replicas[:len(inputs)]
+ kwargs = kwargs[:len(inputs)]
+ outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
+ return nn.parallel.gather(outputs, output_device).mean()
+
+ def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor, **kwargs):
+ encoded = self.encoder(x)
+ out = self.decoder(tgt_seq, context=encoded, **kwargs)
+ return out
+
+ @torch.no_grad()
+ def generate(self, x: torch.Tensor, temperature: float = 0.25):
+ start_tokens = torch.LongTensor([self.args.bos_token] * len(x))[:, None].to(x.device)
+ return self.decoder.generate(
+ start_tokens,
+ self.args.max_seq_len,
+ eos_token=self.args.eos_token,
+ context=self.encoder(x),
+ temperature=temperature,
+ )
+
+
+def get_model(args):
+ if args.encoder_structure.lower() == 'vit':
+ encoder = vit.get_encoder(args)
+ elif args.encoder_structure.lower() == 'hybrid':
+ encoder = hybrid.get_encoder(args)
+ else:
+ raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)
+ decoder = transformer.get_decoder(args)
+ encoder.to(args.device)
+ decoder.to(args.device)
+ model = Model(encoder, decoder, args)
+ if args.wandb:
+ import wandb
+ wandb.watch(model)
+
+ return model
\ No newline at end of file
diff --git a/unimernet/models/vit.py b/unimernet/models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b2c4de691f98c1d0ee2a921fa0aaf8ccc9cdfb
--- /dev/null
+++ b/unimernet/models/vit.py
@@ -0,0 +1,527 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+ Based on timm code base
+ https://github.com/rwightman/pytorch-image-models/tree/master/timm
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+from unimernet.models.base_model import BaseEncoder
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE: the scale factor was wrong in an earlier version of this code; it can be set manually to stay compatible with previous weights
+ self.scale = qk_scale or head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.attn_gradients = None
+ self.attention_map = None
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def forward(self, x, register_hook=False):
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ if register_hook:
+ self.save_attention_map(attn)
+ attn.register_hook(self.save_attn_gradients)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ use_grad_checkpointing=False,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ if use_grad_checkpointing:
+ self.attn = checkpoint_wrapper(self.attn)
+ self.mlp = checkpoint_wrapper(self.mlp)
+
+ def forward(self, x, register_hook=False):
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ representation_size=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=None,
+ use_grad_checkpointing=False,
+ ckpt_layer=0,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer: (nn.Module): normalization layer
+ """
+ super().__init__()
+ self.num_features = (
+ self.embed_dim
+ ) = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ use_grad_checkpointing=(
+ use_grad_checkpointing and i >= depth - ckpt_layer
+ ),
+ )
+ for i in range(depth)
+ ]
+ )
+ self.norm = norm_layer(embed_dim)
+
+ trunc_normal_(self.pos_embed, std=0.02)
+ trunc_normal_(self.cls_token, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed", "cls_token"}
+
+ def forward(self, x, register_blk=-1):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + self.pos_embed[:, : x.size(1), :]
+ x = self.pos_drop(x)
+
+ for i, blk in enumerate(self.blocks):
+ x = blk(x, register_blk == i)
+ x = self.norm(x)
+
+ return x
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=""):
+ _load_weights(self, checkpoint_path, prefix)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
+ """Load weights from .npz checkpoints of the official Google Brain Flax implementation."""
+ import numpy as np
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and "opt/target/embedding/kernel" in w:
+ prefix = "opt/target/"
+
+ if hasattr(model.patch_embed, "backbone"):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, "stem")
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(
+ adapt_input_conv(
+ stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])
+ )
+ )
+ stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
+ stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f"{prefix}block{i + 1}/unit{j + 1}/"
+ for r in range(3):
+ getattr(block, f"conv{r + 1}").weight.copy_(
+ _n2p(w[f"{bp}conv{r + 1}/kernel"])
+ )
+ getattr(block, f"norm{r + 1}").weight.copy_(
+ _n2p(w[f"{bp}gn{r + 1}/scale"])
+ )
+ getattr(block, f"norm{r + 1}").bias.copy_(
+ _n2p(w[f"{bp}gn{r + 1}/bias"])
+ )
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(
+ _n2p(w[f"{bp}conv_proj/kernel"])
+ )
+ block.downsample.norm.weight.copy_(
+ _n2p(w[f"{bp}gn_proj/scale"])
+ )
+ block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
+ embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"])
+ )
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
+ model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
+ pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w,
+ model.pos_embed,
+ getattr(model, "num_tokens", 1),
+ model.patch_embed.grid_size,
+ )
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
+ model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
+ mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
+ block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
+ block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
+ block.attn.qkv.weight.copy_(
+ torch.cat(
+ [
+ _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T
+ for n in ("query", "key", "value")
+ ]
+ )
+ )
+ block.attn.qkv.bias.copy_(
+ torch.cat(
+ [
+ _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1)
+ for n in ("query", "key", "value")
+ ]
+ )
+ )
+ block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
+ for r in range(2):
+ getattr(block.mlp, f"fc{r + 1}").weight.copy_(
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"])
+ )
+ getattr(block.mlp, f"fc{r + 1}").bias.copy_(
+ _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"])
+ )
+ block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
+ block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
+
+
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ print("Resized position embedding: %s to %s" % (posemb.shape, posemb_new.shape))
+ ntok_new = posemb_new.shape[1]
+ if num_tokens:
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+ ntok_new -= num_tokens
+ else:
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ if not len(gs_new): # backwards compatibility
+ gs_new = [int(math.sqrt(ntok_new))] * 2
+ assert len(gs_new) >= 2
+ print("Position embedding grid-size from %s to %s" % ([gs_old, gs_old], gs_new))
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(
+ posemb_grid, size=gs_new, mode="bicubic", align_corners=False
+ )
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+
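+ # Worked example for the position-embedding resizing above (illustrative): a ViT-B/16
+ # checkpoint trained at 224px stores 1 + 14*14 = 197 position embeddings; loading it at
+ # 384px needs 1 + 24*24 = 577, so the 14x14 grid is bicubically interpolated to 24x24 and
+ # concatenated back with the class-token embedding.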
+
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+ # interpolate position embedding
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = visual_encoder.patch_embed.num_patches
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches**0.5)
+
+ if orig_size != new_size:
+ # class_token and dist_token are kept unchanged
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size, orig_size, embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ print(
+ "reshape position embedding from %d to %d" % (orig_size**2, new_size**2)
+ )
+
+ return new_pos_embed
+ else:
+ return pos_embed_checkpoint
+
+
+class VisionTransformerEncoder(VisionTransformer, BaseEncoder):
+ @classmethod
+ def from_config(cls, cfg, from_pretrained=False):
+
+ vit_type = cfg.get("vit_type", "base")
+ image_size = cfg.get("image_size", 384)
+ ckpt_layer = cfg.get("vit_ckpt_layer", 0)
+ drop_path_rate = cfg.get("vit_drop_path_rate", 0)
+ norm_layer_eps = cfg.get("vit_layer_norm_epsilon", -1)
+ use_grad_checkpointing = cfg.get("vit_grad_ckpt", False)
+
+ if norm_layer_eps == -1:
+ norm_layer = None
+ else:
+ norm_layer = partial(nn.LayerNorm, eps=norm_layer_eps)
+
+ # norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ assert vit_type in ["base", "large"], "vit parameter must be base or large"
+ if vit_type == "base":
+ vision_width = 768
+ visual_encoder = cls(
+ img_size=image_size,
+ patch_size=16,
+ embed_dim=vision_width,
+ depth=12,
+ num_heads=12,
+ use_grad_checkpointing=use_grad_checkpointing,
+ ckpt_layer=ckpt_layer,
+ drop_path_rate=0 or drop_path_rate,
+ norm_layer=norm_layer,
+ )
+
+ if from_pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+ map_location="cpu",
+ check_hash=True,
+ )
+ state_dict = checkpoint["model"]
+ state_dict["pos_embed"] = interpolate_pos_embed(
+ state_dict["pos_embed"], visual_encoder
+ )
+ msg = visual_encoder.load_state_dict(state_dict, strict=False)
+
+ elif vit_type == "large":
+ vision_width = 1024
+ visual_encoder = cls(
+ img_size=image_size,
+ patch_size=16,
+ embed_dim=vision_width,
+ depth=24,
+ num_heads=16,
+ use_grad_checkpointing=use_grad_checkpointing,
+ ckpt_layer=ckpt_layer,
+ drop_path_rate=0.1 or drop_path_rate,
+ norm_layer=norm_layer,
+ )
+ if from_pretrained:
+ from timm.models.helpers import load_custom_pretrained
+ from timm.models.vision_transformer import default_cfgs
+
+ load_custom_pretrained(
+ visual_encoder, default_cfgs["vit_large_patch16_224_in21k"]
+ )
+
+ visual_encoder.vision_width = vision_width
+ return visual_encoder
+
+ def forward_features(self, x, register_blk=-1):
+ return super().forward(x, register_blk)
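+
+
+ # Minimal usage sketch (illustrative; the config values below are assumptions):
+ #
+ #   from omegaconf import OmegaConf
+ #   cfg = OmegaConf.create({"vit_type": "base", "image_size": 384})
+ #   encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=False)
+ #   feats = encoder.forward_features(torch.randn(1, 3, 384, 384))
+ #   # feats: (1, 1 + (384 // 16) ** 2, 768) = (1, 577, 768)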
diff --git a/unimernet/processors/__init__.py b/unimernet/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7548d5004f05b9717da6d87b66649ba0dfacda52
--- /dev/null
+++ b/unimernet/processors/__init__.py
@@ -0,0 +1,40 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.processors.base_processor import BaseProcessor
+
+
+from unimernet.processors.blip_processors import (
+ BlipImageTrainProcessor,
+ Blip2ImageTrainProcessor,
+ BlipImageEvalProcessor,
+ BlipCaptionProcessor,
+)
+
+from unimernet.processors.formula_processor import (
+ FormulaImageTrainProcessor,
+ FormulaImageEvalProcessor,
+ FormulaImageMultiScaleTrainProcessor,
+)
+
+from unimernet.common.registry import registry
+
+__all__ = [
+ "BaseProcessor",
+ "BlipCaptionProcessor",
+]
+
+
+def load_processor(name, cfg=None):
+ """
+ Example
+
+ >>> processor = load_processor("blip_caption", cfg=None)
+ """
+ processor = registry.get_processor_class(name).from_config(cfg)
+
+ return processor
diff --git a/unimernet/processors/__pycache__/__init__.cpython-310.pyc b/unimernet/processors/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b62669811f31a126af7d41f9bb43c41e34b56b87
Binary files /dev/null and b/unimernet/processors/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/base_processor.cpython-310.pyc b/unimernet/processors/__pycache__/base_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bcb97b26620df45215edeab07fab15192c725e8
Binary files /dev/null and b/unimernet/processors/__pycache__/base_processor.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/blip_processors.cpython-310.pyc b/unimernet/processors/__pycache__/blip_processors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b7b83f42465f3a3642eafc3afd563e05ee10e40
Binary files /dev/null and b/unimernet/processors/__pycache__/blip_processors.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/formula_processor.cpython-310.pyc b/unimernet/processors/__pycache__/formula_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38eb3428ea69b5ce0192162de3e4ef55c42a1872
Binary files /dev/null and b/unimernet/processors/__pycache__/formula_processor.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/randaugment.cpython-310.pyc b/unimernet/processors/__pycache__/randaugment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..390161bd14da673b7aa1188926b5d44e3752607b
Binary files /dev/null and b/unimernet/processors/__pycache__/randaugment.cpython-310.pyc differ
diff --git a/unimernet/processors/base_processor.py b/unimernet/processors/base_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c9d86859270a046623661a632587f2b3136b46
--- /dev/null
+++ b/unimernet/processors/base_processor.py
@@ -0,0 +1,26 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from omegaconf import OmegaConf
+
+
+class BaseProcessor:
+ def __init__(self):
+ self.transform = lambda x: x
+ return
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ return cls()
+
+ def build(self, **kwargs):
+ cfg = OmegaConf.create(kwargs)
+
+ return self.from_config(cfg)
diff --git a/unimernet/processors/blip_processors.py b/unimernet/processors/blip_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..28d6c4f920a126667bfebf35ef5d6a64b4294fcb
--- /dev/null
+++ b/unimernet/processors/blip_processors.py
@@ -0,0 +1,281 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import re
+
+from unimernet.common.registry import registry
+from unimernet.processors.base_processor import BaseProcessor
+from unimernet.processors.randaugment import RandomAugment
+from omegaconf import OmegaConf
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class BlipImageBaseProcessor(BaseProcessor):
+ def __init__(self, mean=None, std=None):
+ if mean is None:
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ if std is None:
+ std = (0.26862954, 0.26130258, 0.27577711)
+
+ self.normalize = transforms.Normalize(mean, std)
+
+
+@registry.register_processor("blip_caption")
+class BlipCaptionProcessor(BaseProcessor):
+ def __init__(self, prompt="", max_words=50):
+ self.prompt = prompt
+ self.max_words = max_words
+
+ def __call__(self, caption):
+ caption = self.prompt + self.pre_caption(caption)
+
+ return caption
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ prompt = cfg.get("prompt", "")
+ max_words = cfg.get("max_words", 50)
+
+ return cls(prompt=prompt, max_words=max_words)
+
+ def pre_caption(self, caption):
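+ # e.g. (illustrative): pre_caption("A Dog! Running; fast.") -> "a dog running fast"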
+ caption = re.sub(
+ r"([.!\"()*#:;~])",
+ " ",
+ caption.lower(),
+ )
+ caption = re.sub(
+ r"\s{2,}",
+ " ",
+ caption,
+ )
+ caption = caption.rstrip("\n")
+ caption = caption.strip(" ")
+
+ # truncate caption
+ caption_words = caption.split(" ")
+ if len(caption_words) > self.max_words:
+ caption = " ".join(caption_words[: self.max_words])
+
+ return caption
+
+@registry.register_processor("blip_caption_instruct")
+class BlipCaptionInstructProcessor(BaseProcessor):
+ def __init__(self, prompt="", max_words=256):
+ self.prompt = prompt
+ self.max_words = max_words
+
+ def __call__(self, caption):
+ caption = self.prompt + self.pre_caption(caption)
+
+ return caption
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ prompt = cfg.get("prompt", "")
+ max_words = cfg.get("max_words", 256)
+
+ return cls(prompt=prompt, max_words=max_words)
+
+ def pre_caption(self, caption):
+ # caption = re.sub(
+ # r"([.!\"()*#:;~])",
+ # " ",
+ # caption.lower(),
+ # )
+ # caption = re.sub(
+ # r"\s{2,}",
+ # " ",
+ # caption,
+ # )
+ caption = caption.rstrip("\n")
+ caption = caption.strip(" ")
+
+ # # truncate caption
+ # caption_words = caption.split(" ")
+ # if len(caption_words) > self.max_words:
+ # caption = " ".join(caption_words[: self.max_words])
+
+ return caption
+
+
+@registry.register_processor("blip_question")
+class BlipQuestionProcessor(BaseProcessor):
+ def __init__(self, max_words=50):
+ self.max_words = max_words
+
+ def __call__(self, question):
+ return self.pre_question(question)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ max_words = cfg.get("max_words", 50)
+
+ return cls(max_words=max_words)
+
+ def pre_question(self, question):
+ question = re.sub(
+ r"([.!\"()*#:;~])",
+ "",
+ question.lower(),
+ )
+ question = question.rstrip(" ")
+
+ # truncate question
+ question_words = question.split(" ")
+ if len(question_words) > self.max_words:
+ question = " ".join(question_words[: self.max_words])
+
+ return question
+
+
+@registry.register_processor("blip_image_train")
+class BlipImageTrainProcessor(BlipImageBaseProcessor):
+ def __init__(
+ self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
+ ):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ image_size,
+ scale=(min_scale, max_scale),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.RandomHorizontalFlip(),
+ RandomAugment(
+ 2,
+ 5,
+ isPIL=True,
+ augs=[
+ "Identity",
+ "AutoContrast",
+ "Brightness",
+ "Sharpness",
+ "Equalize",
+ "ShearX",
+ "ShearY",
+ "TranslateX",
+ "TranslateY",
+ "Rotate",
+ ],
+ ),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 384)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ min_scale = cfg.get("min_scale", 0.5)
+ max_scale = cfg.get("max_scale", 1.0)
+
+ return cls(
+ image_size=image_size,
+ mean=mean,
+ std=std,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ )
+
+
+@registry.register_processor("blip_image_eval")
+class BlipImageEvalProcessor(BlipImageBaseProcessor):
+ def __init__(self, image_size=384, mean=None, std=None):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+ ),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 384)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ return cls(image_size=image_size, mean=mean, std=std)
+
+
+@registry.register_processor("blip2_image_train")
+class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
+ def __init__(
+ self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
+ ):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ image_size,
+ scale=(min_scale, max_scale),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 364)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ min_scale = cfg.get("min_scale", 0.5)
+ max_scale = cfg.get("max_scale", 1.0)
+
+ return cls(
+ image_size=image_size,
+ mean=mean,
+ std=std,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ )
\ No newline at end of file
diff --git a/unimernet/processors/formula_processor.py b/unimernet/processors/formula_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb12b2f9e5c1cc1f68f82ab008c75f8ccc16fb8
--- /dev/null
+++ b/unimernet/processors/formula_processor.py
@@ -0,0 +1,171 @@
+from unimernet.common.registry import registry
+from omegaconf import OmegaConf
+import albumentations as alb
+from albumentations.pytorch import ToTensorV2
+from unimernet.processors.base_processor import BaseProcessor
+import numpy as np
+import cv2
+from PIL import Image, ImageOps
+from torchvision.transforms.functional import resize
+import random
+from unimernet.processors.formula_processor_helper.nougat import Bitmap, Dilation, Erosion
+from unimernet.processors.formula_processor_helper.weather import Fog, Frost, Snow, Rain, Shadow
+
+
+class FormulaImageBaseProcessor(BaseProcessor):
+
+ def __init__(self, image_size):
+ super(FormulaImageBaseProcessor, self).__init__()
+ self.input_size = [int(_) for _ in image_size]
+ assert len(self.input_size) == 2
+
+ @staticmethod
+ def crop_margin(img: Image.Image) -> Image.Image:
+ data = np.array(img.convert("L"))
+ data = data.astype(np.uint8)
+ max_val = data.max()
+ min_val = data.min()
+ if max_val == min_val:
+ return img
+ data = (data - min_val) / (max_val - min_val) * 255
+ gray = 255 * (data < 200).astype(np.uint8)
+
+ coords = cv2.findNonZero(gray) # Find all non-zero points (text)
+ a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
+ return img.crop((a, b, w + a, h + b))
+
+ def prepare_input(self, img: Image.Image, random_padding: bool = False):
+ """
+ Convert PIL Image to tensor according to specified input_size after following steps below:
+ - resize
+ - rotate (if align_long_axis is True and image is not aligned longer axis with canvas)
+ - pad
+ """
+ if img is None:
+ return
+ # crop margins
+ try:
+ img = self.crop_margin(img.convert("RGB"))
+ except OSError:
+ # might throw an error for broken files
+ return
+
+ if img.height == 0 or img.width == 0:
+ return
+
+ img = resize(img, min(self.input_size))
+ img.thumbnail((self.input_size[1], self.input_size[0]))
+ delta_width = self.input_size[1] - img.width
+ delta_height = self.input_size[0] - img.height
+ if random_padding:
+ pad_width = np.random.randint(low=0, high=delta_width + 1)
+ pad_height = np.random.randint(low=0, high=delta_height + 1)
+ else:
+ pad_width = delta_width // 2
+ pad_height = delta_height // 2
+ padding = (
+ pad_width,
+ pad_height,
+ delta_width - pad_width,
+ delta_height - pad_height,
+ )
+ return ImageOps.expand(img, padding)
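+
+    # Worked example (illustrative), for input_size = [384, 384]:
+    #   a 96x768 (HxW) formula crop -> short side resized to 384 -> 384x3072
+    #   -> thumbnail() shrinks it to fit the canvas -> 48x384
+    #   -> padded to 384x384 (vertically centered at eval time, randomly
+    #      offset when random_padding is True).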
+
+
+@registry.register_processor("formula_image_train")
+class FormulaImageTrainProcessor(FormulaImageBaseProcessor):
+ def __init__(self, image_size=384):
+ super().__init__(image_size)
+
+ self.transform = alb.Compose(
+ [
+ alb.Compose(
+ [
+ Bitmap(p=0.05),
+ alb.OneOf([Fog(), Frost(), Snow(), Rain(), Shadow()], p=0.2),
+ alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.2),
+ alb.ShiftScaleRotate(shift_limit=0, scale_limit=(-.15, 0), rotate_limit=1, border_mode=0,
+ interpolation=3,
+ value=[255, 255, 255],
+ p=1),
+ alb.GridDistortion(distort_limit=0.1, border_mode=0, interpolation=3, value=[255, 255, 255],
+ p=.5)],
+ p=.15),
+ # alb.InvertImg(p=.15),
+ alb.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
+ alb.GaussNoise(10, p=.2),
+ alb.RandomBrightnessContrast(.05, (-.2, 0), True, p=0.2),
+ alb.ImageCompression(95, p=.3),
+ alb.ToGray(always_apply=True),
+ alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
+ # alb.Sharpen()
+ ToTensorV2(),
+ ]
+ )
+
+ def __call__(self, item):
+ img = self.prepare_input(item, random_padding=True)
+ if img is None:
+ return img
+ return self.transform(image=np.array(img))['image'][:1]
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", [384, 384])
+
+ return cls(
+ image_size=image_size,
+ )
+
+
+@registry.register_processor("formula_image_multi_scale_train")
+class FormulaImageMultiScaleTrainProcessor(FormulaImageTrainProcessor):
+ def __init__(self, all_scales):
+ for i, scales in enumerate(all_scales):
+ all_scales[i] = [int(_) for _ in scales]
+ super(FormulaImageMultiScaleTrainProcessor, self).__init__(all_scales[0])
+ self.all_scales = all_scales
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ all_scales = cfg.get("all_scales", [[384, 384]])
+ return cls(
+ all_scales=all_scales
+ )
+
+ def reset_scale(self):
+ self.input_size = random.choice(self.all_scales)
+
+
+@registry.register_processor("formula_image_eval")
+class FormulaImageEvalProcessor(FormulaImageBaseProcessor):
+ def __init__(self, image_size):
+ super().__init__(image_size)
+
+ self.transform = alb.Compose(
+ [
+ alb.ToGray(always_apply=True),
+ alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
+ # alb.Sharpen()
+ ToTensorV2(),
+ ]
+ )
+
+ def __call__(self, item):
+ image = self.prepare_input(item)
+ return self.transform(image=np.array(image))['image'][:1]
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", [384, 384])
+
+ return cls(image_size=image_size)
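+
+
+# Usage sketch (illustrative): the eval processor can be built straight from an
+# OmegaConf node, e.g.
+#   proc = FormulaImageEvalProcessor.from_config(OmegaConf.create({"image_size": [384, 384]}))
+#   tensor = proc(Image.open("formula.png"))  # 1 x 384 x 384 float tensor
+# where "formula.png" is a placeholder path.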
diff --git a/unimernet/processors/formula_processor_helper/__init__.py b/unimernet/processors/formula_processor_helper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/__init__.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c43204a142ca2d57e1d4b5a2f933ac4194b55dde
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/nougat.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/nougat.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..837d17975ab8a3170a2ad42e3ec9df97e7c5b016
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/nougat.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/ops.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/ops.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bcf65d49a90b75e9663617a33efa0963bbe42748
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/ops.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/weather.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/weather.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae350972ba7f2d656db36bdcf258b53bf6e5e226
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/weather.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost1.png b/unimernet/processors/formula_processor_helper/frost/frost1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9edf9b6e1a2744d15af615af641f2aa48aa89c2
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/frost/frost1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff9f907860bd7a835d459e32f9d588062b7f61ee267343cc7222b56753a14755
+size 1199930
diff --git a/unimernet/processors/formula_processor_helper/frost/frost2.png b/unimernet/processors/formula_processor_helper/frost/frost2.png
new file mode 100644
index 0000000000000000000000000000000000000000..48f7a861ffa41b6d7496b701fef96d5edf739282
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost2.png differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost3.png b/unimernet/processors/formula_processor_helper/frost/frost3.png
new file mode 100644
index 0000000000000000000000000000000000000000..d47f9d25f41251ee9a66b294c0bbfad6053017c9
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost3.png differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost4.jpg b/unimernet/processors/formula_processor_helper/frost/frost4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f8b0c413176d70150b593e029d84b4a88c21dd4b
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost4.jpg differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost5.jpg b/unimernet/processors/formula_processor_helper/frost/frost5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..95dc9056926d8201df760535f9bb9112f012e862
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost5.jpg differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost6.jpg b/unimernet/processors/formula_processor_helper/frost/frost6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..14e5d58e762a5d0808df9fa6494fd6d78ee4409b
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost6.jpg differ
diff --git a/unimernet/processors/formula_processor_helper/nougat.py b/unimernet/processors/formula_processor_helper/nougat.py
new file mode 100644
index 0000000000000000000000000000000000000000..c51b99f3233b1b5254e423035b4a2a07816be5a3
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/nougat.py
@@ -0,0 +1,98 @@
+import albumentations as alb
+import numpy as np
+import cv2
+
+
+class Erosion(alb.ImageOnlyTransform):
+ """
+ Apply erosion operation to an image.
+
+ Erosion is a morphological operation that shrinks the white regions in a binary image.
+
+ Args:
+ scale (int or tuple/list of int): The scale or range for the size of the erosion kernel.
+ If an integer is provided, a square kernel of that size will be used.
+ If a tuple or list is provided, it should contain two integers representing the minimum
+ and maximum sizes for the erosion kernel.
+ always_apply (bool, optional): Whether to always apply this transformation. Default is False.
+ p (float, optional): The probability of applying this transformation. Default is 0.5.
+
+ Returns:
+ numpy.ndarray: The transformed image.
+ """
+
+ def __init__(self, scale, always_apply=False, p=0.5):
+ super().__init__(always_apply=always_apply, p=p)
+ if type(scale) is tuple or type(scale) is list:
+ assert len(scale) == 2
+ self.scale = scale
+ else:
+ self.scale = (scale, scale)
+
+ def apply(self, img, **params):
+ kernel = cv2.getStructuringElement(
+ cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
+ )
+ img = cv2.erode(img, kernel, iterations=1)
+ return img
+
+
+class Dilation(alb.ImageOnlyTransform):
+ """
+ Apply dilation operation to an image.
+
+ Dilation is a morphological operation that expands the white regions in a binary image.
+
+ Args:
+ scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
+ If an integer is provided, a square kernel of that size will be used.
+ If a tuple or list is provided, it should contain two integers representing the minimum
+ and maximum sizes for the dilation kernel.
+ always_apply (bool, optional): Whether to always apply this transformation. Default is False.
+ p (float, optional): The probability of applying this transformation. Default is 0.5.
+
+ Returns:
+ numpy.ndarray: The transformed image.
+ """
+
+ def __init__(self, scale, always_apply=False, p=0.5):
+ super().__init__(always_apply=always_apply, p=p)
+ if type(scale) is tuple or type(scale) is list:
+ assert len(scale) == 2
+ self.scale = scale
+ else:
+ self.scale = (scale, scale)
+
+ def apply(self, img, **params):
+ kernel = cv2.getStructuringElement(
+ cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
+ )
+ img = cv2.dilate(img, kernel, iterations=1)
+ return img
+
+
+class Bitmap(alb.ImageOnlyTransform):
+ """
+ Apply a bitmap-style transformation to an image.
+
+ This transformation replaces all pixel values below a certain threshold with a specified value.
+
+ Args:
+ value (int, optional): The value to replace pixels below the threshold with. Default is 0.
+ lower (int, optional): The threshold value below which pixels will be replaced. Default is 200.
+ always_apply (bool, optional): Whether to always apply this transformation. Default is False.
+ p (float, optional): The probability of applying this transformation. Default is 0.5.
+
+ Returns:
+ numpy.ndarray: The transformed image.
+ """
+
+ def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
+ super().__init__(always_apply=always_apply, p=p)
+ self.lower = lower
+ self.value = value
+
+ def apply(self, img, **params):
+ img = img.copy()
+ img[img < self.lower] = self.value
+ return img
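+
+
+if __name__ == "__main__":
+    # Minimal sanity check (illustrative only): apply each transform once to a
+    # synthetic black-on-white image and print the resulting shapes.
+    sample = np.full((64, 256, 3), 255, dtype=np.uint8)
+    sample[28:36, 16:240] = 0  # a fake text stroke
+    for t in (Erosion((2, 3), p=1.0), Dilation((2, 3), p=1.0), Bitmap(p=1.0)):
+        out = t(image=sample)["image"]
+        print(type(t).__name__, out.shape, out.dtype)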
diff --git a/unimernet/processors/formula_processor_helper/ops.py b/unimernet/processors/formula_processor_helper/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b84016300e9237b87d2ce1d82e3f777f1599fe23
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/ops.py
@@ -0,0 +1,88 @@
+"""
+Common image operations
+
+Reference: https://github.com/hendrycks/robustness
+Hacked together for STR by: Rowel Atienza
+"""
+
+import cv2
+import numpy as np
+from scipy.ndimage import zoom as scizoom
+
+
+def clipped_zoom(img, zoom_factor):
+ h = img.shape[1]
+ # ceil crop height(= crop width)
+ ch = int(np.ceil(h / float(zoom_factor)))
+
+ top = (h - ch) // 2
+ img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
+ # trim off any extra pixels
+ trim_top = (img.shape[0] - h) // 2
+
+ return img[trim_top:trim_top + h, trim_top:trim_top + h]
+
+
+def disk(radius, alias_blur=0.1, dtype=np.float32):
+ if radius <= 8:
+ coords = np.arange(-8, 8 + 1)
+ ksize = (3, 3)
+ else:
+ coords = np.arange(-radius, radius + 1)
+ ksize = (5, 5)
+ x, y = np.meshgrid(coords, coords)
+ aliased_disk = np.asarray((x ** 2 + y ** 2) <= radius ** 2, dtype=dtype)
+ aliased_disk /= np.sum(aliased_disk)
+
+ # supersample disk to antialias
+ return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
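+
+
+# Note (context, not exhaustive): in the referenced robustness code, disk() serves
+# as a defocus-blur kernel convolved with the image via cv2.filter2D, and
+# clipped_zoom() implements zoom blur; plasma_fractal() below is the helper used
+# by the Fog transform in weather.py.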
+
+
+# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
+def plasma_fractal(mapsize=256, wibbledecay=3, rng=None):
+ """
+    Generate a heightmap using the diamond-square algorithm.
+    Return a square 2d array of side length 'mapsize', with floats normalized to the range 0-1.
+ 'mapsize' must be a power of two.
+ """
+ assert (mapsize & (mapsize - 1) == 0)
+    maparray = np.empty((mapsize, mapsize), dtype=np.float64)  # np.float_ alias was removed in NumPy 2.0
+ maparray[0, 0] = 0
+ stepsize = mapsize
+ wibble = 100
+ if rng is None:
+ rng = np.random.default_rng()
+
+ def wibbledmean(array):
+ return array / 4 + wibble * rng.uniform(-wibble, wibble, array.shape)
+
+ def fillsquares():
+ """For each square of points stepsize apart,
+ calculate middle value as mean of points + wibble"""
+ cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
+ squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
+ squareaccum += np.roll(squareaccum, shift=-1, axis=1)
+ maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)
+
+ def filldiamonds():
+ """For each diamond of points stepsize apart,
+ calculate middle value as mean of points + wibble"""
+ drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
+ ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
+ ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
+ lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
+ ltsum = ldrsum + lulsum
+ maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
+ tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
+ tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
+ ttsum = tdrsum + tulsum
+ maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)
+
+ while stepsize >= 2:
+ fillsquares()
+ filldiamonds()
+ stepsize //= 2
+ wibble /= wibbledecay
+
+ maparray -= maparray.min()
+ return maparray / maparray.max()
\ No newline at end of file
diff --git a/unimernet/processors/formula_processor_helper/weather.py b/unimernet/processors/formula_processor_helper/weather.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa5233958ad3f076df2d07b2c65b6e1f5cd29d7a
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/weather.py
@@ -0,0 +1,245 @@
+import math
+from io import BytesIO
+
+import cv2
+import numpy as np
+from PIL import Image, ImageOps, ImageDraw
+from pkg_resources import resource_filename
+from wand.image import Image as WandImage
+import albumentations as alb
+
+from .ops import plasma_fractal
+
+
+class Fog(alb.ImageOnlyTransform):
+ def __init__(self, mag=-1, always_apply=False, p=1.):
+ super().__init__(always_apply=always_apply, p=p)
+ self.rng = np.random.default_rng()
+ self.mag = mag
+
+ def apply(self, img, **params):
+ img = Image.fromarray(img.astype(np.uint8))
+ w, h = img.size
+ c = [(1.5, 2), (2., 2), (2.5, 1.7)]
+ if self.mag < 0 or self.mag >= len(c):
+ index = self.rng.integers(0, len(c))
+ else:
+ index = self.mag
+ c = c[index]
+
+ n_channels = len(img.getbands())
+ isgray = n_channels == 1
+
+ img = np.asarray(img) / 255.
+ max_val = img.max()
+ # Make sure fog image is at least twice the size of the input image
+ max_size = 2 ** math.ceil(math.log2(max(w, h)) + 1)
+ fog = c[0] * plasma_fractal(mapsize=max_size, wibbledecay=c[1], rng=self.rng)[:h, :w][..., np.newaxis]
+ # x += c[0] * plasma_fractal(wibbledecay=c[1])[:224, :224][..., np.newaxis]
+ # return np.clip(x * max_val / (max_val + c[0]), 0, 1) * 255
+ if isgray:
+ fog = np.squeeze(fog)
+ else:
+ fog = np.repeat(fog, 3, axis=2)
+
+ img += fog
+ img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255
+ return img.astype(np.uint8)
+
+
+class Frost(alb.ImageOnlyTransform):
+ def __init__(self, mag=-1, always_apply=False, p=1.):
+ super().__init__(always_apply=always_apply, p=p)
+ self.rng = np.random.default_rng()
+ self.mag = mag
+
+ def apply(self, img, **params):
+ img = Image.fromarray(img.astype(np.uint8))
+ w, h = img.size
+ c = [(0.78, 0.22), (0.64, 0.36), (0.5, 0.5)]
+ if self.mag < 0 or self.mag >= len(c):
+ index = self.rng.integers(0, len(c))
+ else:
+ index = self.mag
+ c = c[index]
+
+ filename = [resource_filename(__name__, 'frost/frost1.png'),
+ resource_filename(__name__, 'frost/frost2.png'),
+ resource_filename(__name__, 'frost/frost3.png'),
+ resource_filename(__name__, 'frost/frost4.jpg'),
+ resource_filename(__name__, 'frost/frost5.jpg'),
+ resource_filename(__name__, 'frost/frost6.jpg')]
+ index = self.rng.integers(0, len(filename))
+ filename = filename[index]
+ # Some images have transparency. Remove alpha channel.
+ frost = Image.open(filename).convert('RGB')
+
+ # Resize the frost image to match the input image's dimensions
+ f_w, f_h = frost.size
+ if w / h > f_w / f_h:
+ f_h = round(f_h * w / f_w)
+ f_w = w
+ else:
+ f_w = round(f_w * h / f_h)
+ f_h = h
+ frost = np.asarray(frost.resize((f_w, f_h)))
+
+ # randomly crop
+ y_start, x_start = self.rng.integers(0, f_h - h + 1), self.rng.integers(0, f_w - w + 1)
+ frost = frost[y_start:y_start + h, x_start:x_start + w]
+
+ n_channels = len(img.getbands())
+ isgray = n_channels == 1
+
+ img = np.asarray(img)
+
+ if isgray:
+ img = np.expand_dims(img, axis=2)
+ img = np.repeat(img, 3, axis=2)
+
+ img = np.clip(np.round(c[0] * img + c[1] * frost), 0, 255)
+ img = img.astype(np.uint8)
+ if isgray:
+ img = np.squeeze(img)
+ return img
+
+
+class Snow(alb.ImageOnlyTransform):
+ def __init__(self, mag=-1, always_apply=False, p=1.):
+ super().__init__(always_apply=always_apply, p=p)
+ self.rng = np.random.default_rng()
+ self.mag = mag
+
+ def apply(self, img, **params):
+ img = Image.fromarray(img.astype(np.uint8))
+ w, h = img.size
+ c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8),
+ (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
+ (0.55, 0.3, 4, 0.9, 12, 8, 0.7)]
+ if self.mag < 0 or self.mag >= len(c):
+ index = self.rng.integers(0, len(c))
+ else:
+ index = self.mag
+ c = c[index]
+
+ n_channels = len(img.getbands())
+ isgray = n_channels == 1
+
+ img = np.asarray(img, dtype=np.float32) / 255.
+ if isgray:
+ img = np.expand_dims(img, axis=2)
+ img = np.repeat(img, 3, axis=2)
+
+ snow_layer = self.rng.normal(size=img.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome
+
+ # snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
+ snow_layer[snow_layer < c[3]] = 0
+
+ snow_layer = Image.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
+ output = BytesIO()
+ snow_layer.save(output, format='PNG')
+ snow_layer = WandImage(blob=output.getvalue())
+
+ snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=self.rng.uniform(-135, -45))
+
+ snow_layer = cv2.imdecode(np.frombuffer(snow_layer.make_blob(), np.uint8),
+ cv2.IMREAD_UNCHANGED) / 255.
+
+ # snow_layer = cv2.cvtColor(snow_layer, cv2.COLOR_BGR2RGB)
+
+ snow_layer = snow_layer[..., np.newaxis]
+
+ img = c[6] * img
+ gray_img = (1 - c[6]) * np.maximum(img, cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(h, w, 1) * 1.5 + 0.5)
+ img += gray_img
+ img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
+ img = img.astype(np.uint8)
+ if isgray:
+ img = np.squeeze(img)
+ return img
+
+
+class Rain(alb.ImageOnlyTransform):
+ def __init__(self, mag=-1, always_apply=False, p=1.):
+ super().__init__(always_apply=always_apply, p=p)
+ self.rng = np.random.default_rng()
+ self.mag = mag
+
+ def apply(self, img, **params):
+ img = Image.fromarray(img.astype(np.uint8))
+ img = img.copy()
+ w, h = img.size
+ n_channels = len(img.getbands())
+ isgray = n_channels == 1
+ line_width = self.rng.integers(1, 2)
+
+ c = [50, 70, 90]
+ if self.mag < 0 or self.mag >= len(c):
+ index = 0
+ else:
+ index = self.mag
+ c = c[index]
+
+ n_rains = self.rng.integers(c, c + 20)
+ slant = self.rng.integers(-60, 60)
+ fillcolor = 200 if isgray else (200, 200, 200)
+
+ draw = ImageDraw.Draw(img)
+ max_length = min(w, h, 10)
+ for i in range(1, n_rains):
+ length = self.rng.integers(5, max_length)
+ x1 = self.rng.integers(0, w - length)
+ y1 = self.rng.integers(0, h - length)
+ x2 = x1 + length * math.sin(slant * math.pi / 180.)
+ y2 = y1 + length * math.cos(slant * math.pi / 180.)
+ x2 = int(x2)
+ y2 = int(y2)
+ draw.line([(x1, y1), (x2, y2)], width=line_width, fill=fillcolor)
+ img = np.asarray(img).astype(np.uint8)
+ return img
+
+
+class Shadow(alb.ImageOnlyTransform):
+ def __init__(self, mag=-1, always_apply=False, p=1.):
+ super().__init__(always_apply=always_apply, p=p)
+ self.rng = np.random.default_rng()
+ self.mag = mag
+
+ def apply(self, img, **params):
+ img = Image.fromarray(img.astype(np.uint8))
+ # img = img.copy()
+ w, h = img.size
+ n_channels = len(img.getbands())
+ isgray = n_channels == 1
+
+ c = [64, 96, 128]
+ if self.mag < 0 or self.mag >= len(c):
+ index = 0
+ else:
+ index = self.mag
+ c = c[index]
+
+ img = img.convert('RGBA')
+ overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
+ draw = ImageDraw.Draw(overlay)
+ transparency = self.rng.integers(c, c + 32)
+ x1 = self.rng.integers(0, w // 2)
+ y1 = 0
+
+ x2 = self.rng.integers(w // 2, w)
+ y2 = 0
+
+ x3 = self.rng.integers(w // 2, w)
+ y3 = h - 1
+
+ x4 = self.rng.integers(0, w // 2)
+ y4 = h - 1
+
+ draw.polygon([(x1, y1), (x2, y2), (x3, y3), (x4, y4)], fill=(0, 0, 0, transparency))
+
+ img = Image.alpha_composite(img, overlay)
+ img = img.convert("RGB")
+ if isgray:
+ img = ImageOps.grayscale(img)
+ img = np.asarray(img).astype(np.uint8)
+ return img
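+
+
+# Usage sketch (illustrative): each class above is a regular albumentations
+# ImageOnlyTransform and can also be applied on its own, e.g.
+#   aug = Rain(mag=1, p=1.0)
+#   out = aug(image=np.asarray(Image.open("sample.png").convert("RGB")))["image"]
+# where "sample.png" is a placeholder path; in this repo the transforms are
+# composed with alb.OneOf inside FormulaImageTrainProcessor.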
diff --git a/unimernet/processors/randaugment.py b/unimernet/processors/randaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c6a9e6d62f74358f490d19546c9829b3ac6aaef
--- /dev/null
+++ b/unimernet/processors/randaugment.py
@@ -0,0 +1,398 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import cv2
+import numpy as np
+
+import torch
+
+
+## aug functions
+def identity_func(img):
+ return img
+
+
+def autocontrast_func(img, cutoff=0):
+ """
+ same output as PIL.ImageOps.autocontrast
+ """
+ n_bins = 256
+
+ def tune_channel(ch):
+ n = ch.size
+ cut = cutoff * n // 100
+ if cut == 0:
+ high, low = ch.max(), ch.min()
+ else:
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ low = np.argwhere(np.cumsum(hist) > cut)
+ low = 0 if low.shape[0] == 0 else low[0]
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+ if high <= low:
+ table = np.arange(n_bins)
+ else:
+ scale = (n_bins - 1) / (high - low)
+ offset = -low * scale
+ table = np.arange(n_bins) * scale + offset
+ table[table < 0] = 0
+ table[table > n_bins - 1] = n_bins - 1
+ table = table.clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def equalize_func(img):
+ """
+ same output as PIL.ImageOps.equalize
+    PIL's implementation is different from cv2.equalizeHist
+ """
+ n_bins = 256
+
+ def tune_channel(ch):
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ non_zero_hist = hist[hist != 0].reshape(-1)
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+ if step == 0:
+ return ch
+ n = np.empty_like(hist)
+ n[0] = step // 2
+ n[1:] = hist[:-1]
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+ """
+ like PIL, rotate by degree, not radians
+ """
+ H, W = img.shape[0], img.shape[1]
+ center = W / 2, H / 2
+ M = cv2.getRotationMatrix2D(center, degree, 1)
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+ return out
+
+
+def solarize_func(img, thresh=128):
+ """
+    same output as PIL.ImageOps.solarize
+ """
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
+ table = table.clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def color_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Color
+ """
+ ## implementation according to PIL definition, quite slow
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+ # out = blend(degenerate, img, factor)
+ # M = (
+ # np.eye(3) * factor
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+ # )[np.newaxis, np.newaxis, :]
+ M = np.float32(
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+ return out
+
+
+def contrast_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Contrast
+ """
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+ table = (
+ np.array([(el - mean) * factor + mean for el in range(256)])
+ .clip(0, 255)
+ .astype(np.uint8)
+ )
+ out = table[img]
+ return out
+
+
+def brightness_func(img, factor):
+ """
+    same output as PIL.ImageEnhance.Brightness
+ """
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def sharpness_func(img, factor):
+ """
+    The differences between this result and PIL's are all on the 4 boundaries; the center
+    areas are the same
+ """
+ kernel = np.ones((3, 3), dtype=np.float32)
+ kernel[1][1] = 5
+ kernel /= 13
+ degenerate = cv2.filter2D(img, -1, kernel)
+ if factor == 0.0:
+ out = degenerate
+ elif factor == 1.0:
+ out = img
+ else:
+ out = img.astype(np.float32)
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+ out = out.astype(np.uint8)
+ return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+ """
+ same output as PIL.Image.transform
+ """
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+ """
+ same output as PIL.Image.transform
+ """
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def posterize_func(img, bits):
+ """
+ same output as PIL.ImageOps.posterize
+ """
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+ return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+ replace = np.array(replace, dtype=np.uint8)
+ H, W = img.shape[0], img.shape[1]
+ rh, rw = np.random.random(2)
+ pad_size = pad_size // 2
+ ch, cw = int(rh * H), int(rw * W)
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+ out = img.copy()
+ out[x1:x2, y1:y2, :] = replace
+ return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+
+ return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 0.3
+ if np.random.random() > 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * float(translate_const)
+ if np.random.random() > 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * cutout_const)
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 256)
+ return (level,)
+
+ return level_to_args
+
+
+def none_level_to_args(level):
+ return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 4)
+ return (level,)
+
+ return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 30
+ if np.random.random() < 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+func_dict = {
+ "Identity": identity_func,
+ "AutoContrast": autocontrast_func,
+ "Equalize": equalize_func,
+ "Rotate": rotate_func,
+ "Solarize": solarize_func,
+ "Color": color_func,
+ "Contrast": contrast_func,
+ "Brightness": brightness_func,
+ "Sharpness": sharpness_func,
+ "ShearX": shear_x_func,
+ "TranslateX": translate_x_func,
+ "TranslateY": translate_y_func,
+ "Posterize": posterize_func,
+ "ShearY": shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+ "Identity": none_level_to_args,
+ "AutoContrast": none_level_to_args,
+ "Equalize": none_level_to_args,
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
+ "Color": enhance_level_to_args(MAX_LEVEL),
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
+}
+
+
+class RandomAugment(object):
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+ self.N = N
+ self.M = M
+ self.isPIL = isPIL
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N)
+ return [(op, 0.5, self.M) for op in sampled_ops]
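+
+    # Example (illustrative): with N=2, M=5 this may return
+    # [("ShearX", 0.5, 5), ("Equalize", 0.5, 5)]; __call__ then applies each
+    # sampled op with probability 0.5 at magnitude 5.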
+
+ def __call__(self, img):
+ if self.isPIL:
+ img = np.array(img)
+ ops = self.get_random_ops()
+ for name, prob, level in ops:
+ if np.random.random() > prob:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return img
+
+
+class VideoRandomAugment(object):
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
+ self.N = N
+ self.M = M
+ self.p = p
+ self.tensor_in_tensor_out = tensor_in_tensor_out
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
+ return [(op, self.M) for op in sampled_ops]
+
+ def __call__(self, frames):
+ assert (
+ frames.shape[-1] == 3
+ ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
+
+ if self.tensor_in_tensor_out:
+ frames = frames.numpy().astype(np.uint8)
+
+ num_frames = frames.shape[0]
+
+ ops = num_frames * [self.get_random_ops()]
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
+
+ frames = torch.stack(
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
+ ).float()
+
+ return frames
+
+ def _aug(self, img, ops, apply_or_not):
+ for i, (name, level) in enumerate(ops):
+ if not apply_or_not[i]:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return torch.from_numpy(img)
+
+
+if __name__ == "__main__":
+ a = RandomAugment()
+ img = np.random.randn(32, 32, 3)
+ a(img)
diff --git a/unimernet/runners/__init__.py b/unimernet/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2960f292b80ba543877d92c96d94eb3cddaed22
--- /dev/null
+++ b/unimernet/runners/__init__.py
@@ -0,0 +1,11 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.runners.runner_base import RunnerBase
+from unimernet.runners.runner_iter import RunnerIter
+
+__all__ = ["RunnerBase", "RunnerIter"]
diff --git a/unimernet/runners/runner_base.py b/unimernet/runners/runner_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c5eccf38efa6e745dc06ee98bb531daa17affd3
--- /dev/null
+++ b/unimernet/runners/runner_base.py
@@ -0,0 +1,670 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import json
+import logging
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from unimernet.common.dist_utils import (
+ download_cached_file,
+ get_rank,
+ get_world_size,
+ is_main_process,
+ main_process,
+)
+from unimernet.common.registry import registry
+from unimernet.common.utils import is_url
+from unimernet.datasets.data_utils import reorg_datasets_by_split, concat_datasets
+from unimernet.datasets.datasets.dataloader_utils import (
+ IterLoader,
+ MultiIterLoader,
+ ConcatLoader,
+ PrefetchLoader,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data.dataset import ChainDataset
+
+
+@registry.register_runner("runner_base")
+class RunnerBase:
+ """
+ A runner class to train and evaluate a model given a task and datasets.
+
+    The runner uses PyTorch DistributedDataParallel by default. Future releases
+    will support other distributed frameworks.
+ """
+
+ def __init__(self, cfg, task, model, datasets, job_id):
+ self.config = cfg
+ self.job_id = job_id
+
+ self.task = task
+ self.datasets = datasets
+
+ self._model = model
+
+ self._wrapped_model = None
+ self._device = None
+ self._optimizer = None
+ self._scaler = None
+ self._dataloaders = None
+ self._lr_sched = None
+
+ self.start_epoch = 0
+
+ # self.setup_seeds()
+ self.setup_output_dir()
+
+ @property
+ def device(self):
+ if self._device is None:
+ self._device = torch.device(self.config.run_cfg.device)
+
+ return self._device
+
+ @property
+ def milestone(self):
+ return self.config.run_cfg.get("milestone", None)
+
+ @property
+ def use_distributed(self):
+ return self.config.run_cfg.distributed
+
+ @property
+ def model(self):
+ """
+ A property to get the DDP-wrapped model on the device.
+ """
+ # move model to device
+ if self._model.device != self.device:
+ self._model = self._model.to(self.device)
+
+ # distributed training wrapper
+ if self.use_distributed:
+ if self._wrapped_model is None:
+ self._wrapped_model = DDP(
+ self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=False
+ )
+ else:
+ self._wrapped_model = self._model
+
+ return self._wrapped_model
+
+ @property
+ def optimizer(self):
+ # TODO make optimizer class and configurations
+ if self._optimizer is None:
+ num_parameters = 0
+ p_wd, p_non_wd = [], []
+ for n, p in self.model.named_parameters():
+ if not p.requires_grad:
+ continue # frozen weights
+ if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
+ p_non_wd.append(p)
+ else:
+ p_wd.append(p)
+ num_parameters += p.data.nelement()
+ logging.info("number of trainable parameters: %d" % num_parameters)
+ optim_params = [
+ {
+ "params": p_wd,
+ "weight_decay": float(self.config.run_cfg.weight_decay),
+ },
+ {"params": p_non_wd, "weight_decay": 0},
+ ]
+ beta2 = self.config.run_cfg.get("beta2", 0.999)
+ self._optimizer = torch.optim.AdamW(
+ optim_params,
+ lr=float(self.config.run_cfg.init_lr),
+ weight_decay=float(self.config.run_cfg.weight_decay),
+ betas=(0.9, beta2),
+ )
+
+ return self._optimizer
+
+ @property
+ def scaler(self):
+ amp = self.config.run_cfg.get("amp", False)
+
+ if amp:
+ if self._scaler is None:
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ return self._scaler
+
+ @property
+ def lr_scheduler(self):
+ """
+        A property to lazily create the learning rate scheduler on first access.
+ """
+ if self._lr_sched is None:
+ lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
+
+ # max_epoch = self.config.run_cfg.max_epoch
+ max_epoch = self.max_epoch
+ # min_lr = self.config.run_cfg.min_lr
+ min_lr = self.min_lr
+ # init_lr = self.config.run_cfg.init_lr
+ init_lr = self.init_lr
+
+ # optional parameters
+ decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
+ warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
+ warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
+ iters_per_epoch = self.config.run_cfg.get("iters_per_inner_epoch", len(self.train_loader))
+
+ self._lr_sched = lr_sched_cls(
+ optimizer=self.optimizer,
+ max_epoch=max_epoch,
+ min_lr=min_lr,
+ init_lr=init_lr,
+ decay_rate=decay_rate,
+ warmup_start_lr=warmup_start_lr,
+ warmup_steps=warmup_steps,
+ iters_per_epoch=iters_per_epoch,
+ )
+
+ return self._lr_sched
+
+ @property
+ def dataloaders(self) -> dict:
+ """
+        A property to lazily create dataloaders, organized by split, on first access.
+
+ If no train_dataset_ratio is provided, concatenate map-style datasets and
+ chain wds.DataPipe datasets separately. Training set becomes a tuple
+ (ConcatDataset, ChainDataset), both are optional but at least one of them is
+ required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
+
+ If train_dataset_ratio is provided, create a MultiIterLoader to sample
+ each dataset by ratios during training.
+
+ Currently do not support multiple datasets for validation and test.
+
+ Returns:
+ dict: {split_name: (tuples of) dataloader}
+ """
+ if self._dataloaders is None:
+            # reorganize datasets by split and concatenate/chain if necessary
+
+ datasets = reorg_datasets_by_split(self.datasets)
+ self.datasets = concat_datasets(datasets)
+
+ self.datasets = {
+ k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
+ }
+
+ # print dataset statistics after concatenation/chaining
+ for split_name in self.datasets:
+ if isinstance(self.datasets[split_name], tuple) or isinstance(
+ self.datasets[split_name], list
+ ):
+ # mixed wds.DataPipeline and torch.utils.data.Dataset
+ num_records = sum(
+ [
+ len(d)
+ if not type(d) in [wds.DataPipeline, ChainDataset]
+ else 0
+ for d in self.datasets[split_name]
+ ]
+ )
+
+ else:
+ if hasattr(self.datasets[split_name], "__len__"):
+ # a single map-style dataset
+ num_records = len(self.datasets[split_name])
+ else:
+ # a single wds.DataPipeline
+ num_records = -1
+ logging.info(
+ "Only a single wds.DataPipeline dataset, no __len__ attribute."
+ )
+
+ if num_records >= 0:
+ logging.info(
+ "Loaded {} records for {} split from the dataset.".format(
+ num_records, split_name
+ )
+ )
+
+ # create dataloaders
+ split_names = sorted(self.datasets.keys())
+
+ datasets = [self.datasets[split] for split in split_names]
+ is_trains = [split in self.train_splits for split in split_names]
+
+ batch_sizes = [
+ self.config.run_cfg.batch_size_train
+ if split == "train"
+ else self.config.run_cfg.batch_size_eval
+ for split in split_names
+ ]
+
+ collate_fns = []
+ for dataset in datasets:
+ if isinstance(dataset, tuple) or isinstance(dataset, list):
+ collate_fns.append([getattr(d, "collater", None) for d in dataset])
+ else:
+ collate_fns.append(getattr(dataset, "collater", None))
+
+ dataloaders = self.create_loaders(
+ datasets=datasets,
+ num_workers=self.config.run_cfg.num_workers,
+ batch_sizes=batch_sizes,
+ is_trains=is_trains,
+ collate_fns=collate_fns,
+ # concat=True
+ )
+
+ self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+ return self._dataloaders
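+
+    # Shape of the result (illustrative): something like
+    #   {"train": <IterLoader or MultiIterLoader>, "eval": <PrefetchLoader>}
+    # -- only training loaders are wrapped so they can be iterated across epochs.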
+
+ @property
+ def cuda_enabled(self):
+ return self.device.type == "cuda"
+
+ @property
+ def max_epoch(self):
+ return int(self.config.run_cfg.max_epoch)
+
+ @property
+ def log_freq(self):
+ log_freq = self.config.run_cfg.get("log_freq", 50)
+ return int(log_freq)
+
+ @property
+ def init_lr(self):
+ return float(self.config.run_cfg.init_lr)
+
+ @property
+ def min_lr(self):
+ return float(self.config.run_cfg.min_lr)
+
+ @property
+ def accum_grad_iters(self):
+ return int(self.config.run_cfg.get("accum_grad_iters", 1))
+
+ @property
+ def valid_splits(self):
+ valid_splits = self.config.run_cfg.get("valid_splits", [])
+
+ if len(valid_splits) == 0:
+ logging.info("No validation splits found.")
+
+ return valid_splits
+
+ @property
+ def test_splits(self):
+ test_splits = self.config.run_cfg.get("test_splits", [])
+
+ return test_splits
+
+ @property
+ def train_splits(self):
+ train_splits = self.config.run_cfg.get("train_splits", [])
+
+ if len(train_splits) == 0:
+ logging.info("Empty train splits.")
+
+ return train_splits
+
+ @property
+ def evaluate_only(self):
+ """
+ Set to True to skip training.
+ """
+ return self.config.run_cfg.evaluate
+
+ @property
+ def use_dist_eval_sampler(self):
+ return self.config.run_cfg.get("use_dist_eval_sampler", True)
+
+ @property
+ def resume_ckpt_path(self):
+ return self.config.run_cfg.get("resume_ckpt_path", None)
+
+ @property
+ def train_loader(self):
+ train_dataloader = self.dataloaders["train"]
+
+ return train_dataloader
+
+ def setup_output_dir(self):
+ lib_root = Path(registry.get_path("library_root"))
+
+ output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
+ result_dir = output_dir / "result"
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ registry.register_path("result_dir", str(result_dir))
+ registry.register_path("output_dir", str(output_dir))
+
+ self.result_dir = result_dir
+ self.output_dir = output_dir
+
+ def train(self):
+ start_time = time.time()
+ best_agg_metric = 0
+ best_epoch = 0
+
+ self.log_config()
+
+ # resume from checkpoint if specified
+ if not self.evaluate_only and self.resume_ckpt_path is not None:
+ self._load_checkpoint(self.resume_ckpt_path)
+
+ for cur_epoch in range(self.start_epoch, self.max_epoch):
+ # training phase
+ if not self.evaluate_only:
+ logging.info("Start training")
+ train_stats = self.train_epoch(cur_epoch)
+ self.log_stats(split_name="train", stats=train_stats)
+
+ # evaluation phase
+ if len(self.valid_splits) > 0:
+ for split_name in self.valid_splits:
+ logging.info("Evaluating on {}.".format(split_name))
+
+ val_log = self.eval_epoch(
+ split_name=split_name, cur_epoch=cur_epoch
+ )
+ if val_log is not None:
+ if is_main_process():
+ assert (
+ "agg_metrics" in val_log
+ ), "No agg_metrics found in validation log."
+
+ agg_metrics = val_log["agg_metrics"]
+ if agg_metrics > best_agg_metric and split_name == "eval":
+ best_epoch, best_agg_metric = cur_epoch, agg_metrics
+
+ self._save_checkpoint(cur_epoch, is_best=True)
+
+ val_log.update({"best_epoch": best_epoch})
+ self.log_stats(val_log, split_name)
+
+ if self.evaluate_only:
+ break
+ if self.milestone and cur_epoch + 1 in self.milestone:
+ self._save_checkpoint(cur_epoch)
+ self._save_checkpoint(cur_epoch, latest=True)
+ dist.barrier()
+
+ # testing phase
+ test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
+ self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logging.info("Training time {}".format(total_time_str))
+
+ def evaluate(self, cur_epoch="best", skip_reload=False):
+ test_logs = dict()
+
+ if len(self.test_splits) > 0:
+ for split_name in self.test_splits:
+ test_logs[split_name] = self.eval_epoch(
+ split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
+ )
+
+ return test_logs
+
+ def train_epoch(self, epoch):
+ # train
+ self.model.train()
+
+ return self.task.train_epoch(
+ epoch=epoch,
+ model=self.model,
+ data_loader=self.train_loader,
+ optimizer=self.optimizer,
+ scaler=self.scaler,
+ lr_scheduler=self.lr_scheduler,
+ cuda_enabled=self.cuda_enabled,
+ log_freq=self.log_freq,
+ accum_grad_iters=self.accum_grad_iters,
+ )
+
+ @torch.no_grad()
+ def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
+ """
+ Evaluate the model on a given split.
+
+ Args:
+ split_name (str): name of the split to evaluate on.
+ cur_epoch (int): current epoch.
+            skip_reload (bool): whether to skip reloading the best checkpoint.
+                During training, we reload the best checkpoint for validation.
+                During testing, we use the provided weights and skip reloading the best checkpoint.
+ """
+ data_loader = self.dataloaders.get(split_name, None)
+ assert data_loader, "data_loader for split {} is None.".format(split_name)
+
+ # TODO In validation, you need to compute loss as well as metrics
+ # TODO consider moving to model.before_evaluation()
+ model = self.unwrap_dist_model(self.model)
+ if not skip_reload and cur_epoch == "best":
+ model = self._reload_best_model(model)
+ model.eval()
+
+ self.task.before_evaluation(
+ model=model,
+ dataset=self.datasets[split_name],
+ )
+ results = self.task.evaluation(model, data_loader)
+
+ if results is not None:
+ return self.task.after_evaluation(
+ val_result=results,
+ split_name=split_name,
+ epoch=cur_epoch,
+ )
+
+ def unwrap_dist_model(self, model):
+ if self.use_distributed:
+ return model.module
+ else:
+ return model
+
+ def create_loaders(
+ self,
+ datasets,
+ num_workers,
+ batch_sizes,
+ is_trains,
+ collate_fns,
+ concat=False
+ ):
+ """
+ Create dataloaders for training and validation.
+ """
+
+ def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
+ # create a single dataloader for each split
+ if isinstance(dataset, ChainDataset) or isinstance(
+ dataset, wds.DataPipeline
+ ):
+                # wds.WebDataset instances are chained together
+ # webdataset.DataPipeline has its own sampler and collate_fn
+ loader = iter(
+ DataLoader(
+ dataset,
+ batch_size=bsz,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
+ )
+ else:
+                # map-style datasets are concatenated together
+ # setup distributed sampler
+ if self.use_distributed:
+ sampler = DistributedSampler(
+ dataset,
+ shuffle=is_train,
+ num_replicas=get_world_size(),
+ rank=get_rank(),
+ )
+ if not self.use_dist_eval_sampler:
+ # e.g. retrieval evaluation
+ sampler = sampler if is_train else None
+ else:
+ sampler = None
+
+ loader = DataLoader(
+ dataset,
+ batch_size=bsz,
+ num_workers=num_workers,
+ pin_memory=True,
+ sampler=sampler,
+ shuffle=sampler is None and is_train,
+ collate_fn=collate_fn,
+ drop_last=True if is_train else False,
+ )
+ loader = PrefetchLoader(loader)
+
+ if is_train:
+ loader = IterLoader(loader, use_distributed=self.use_distributed)
+
+ return loader
+
+ loaders = []
+
+ for dataset, bsz, is_train, collate_fn in zip(
+ datasets, batch_sizes, is_trains, collate_fns
+ ):
+ if isinstance(dataset, list) or isinstance(dataset, tuple):
+ if not concat:
+ sample_ratios = [d.sample_ratio for d in dataset]
+ loader = MultiIterLoader(
+ loaders=[
+ _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
+ for i, d in enumerate(dataset)
+ ],
+ ratios=sample_ratios
+ )
+ else:
+ loader = ConcatLoader(
+ loaders=[
+ _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
+ for i, d in enumerate(dataset)
+ ]
+ )
+
+ else:
+ loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
+
+ loaders.append(loader)
+
+ return loaders
+
+ @main_process
+ def _save_checkpoint(self, cur_epoch, is_best=False, latest=False):
+ """
+ Save the checkpoint at the current epoch.
+ """
+        assert not (is_best and latest), "You can't set 'is_best' and 'latest' at the same time."
+ model_no_ddp = self.unwrap_dist_model(self.model)
+ param_grad_dic = {
+ k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
+ }
+ state_dict = model_no_ddp.state_dict()
+ for k in list(state_dict.keys()):
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
+ # delete parameters that do not require gradient
+ del state_dict[k]
+ save_obj = {
+ "model": state_dict,
+ "optimizer": self.optimizer.state_dict(),
+ "config": self.config.to_dict(),
+ "scaler": self.scaler.state_dict() if self.scaler else None,
+ "epoch": cur_epoch,
+ }
+ if is_best:
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format("best"),
+ )
+ elif latest:
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format("latest"),
+ )
+ else:
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format(cur_epoch+1),
+ )
+ logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch+1, save_to))
+ torch.save(save_obj, save_to)
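+
+        # Resulting files under self.output_dir: checkpoint_best.pth,
+        # checkpoint_latest.pth, or checkpoint_<epoch+1>.pth, each holding
+        # {"model", "optimizer", "config", "scaler", "epoch"}.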
+
+ def _reload_best_model(self, model):
+ """
+ Load the best checkpoint for evaluation.
+ """
+ checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
+
+ logging.info("Loading checkpoint from {}.".format(checkpoint_path))
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ try:
+ model.load_state_dict(checkpoint["model"])
+ except RuntimeError as e:
+ logging.warning(
+ """
+ Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
+ Trying to load the model with strict=False.
+ """
+ )
+ model.load_state_dict(checkpoint["model"], strict=False)
+ return model
+
+ def _load_checkpoint(self, url_or_filename):
+ """
+ Resume from a checkpoint.
+ """
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location=self.device)
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location=self.device)
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+ self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ if self.scaler and "scaler" in checkpoint:
+ self.scaler.load_state_dict(checkpoint["scaler"])
+
+ self.start_epoch = checkpoint["epoch"]
+ logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+ @main_process
+ def log_stats(self, stats, split_name):
+ if isinstance(stats, dict):
+ log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ elif isinstance(stats, list):
+ pass
+
+ @main_process
+ def log_config(self):
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+ f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
diff --git a/unimernet/runners/runner_iter.py b/unimernet/runners/runner_iter.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b0a87ccb30c9c9b7993187df302525f866cb48
--- /dev/null
+++ b/unimernet/runners/runner_iter.py
@@ -0,0 +1,309 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import os
+import time
+
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from unimernet.common.dist_utils import download_cached_file, is_main_process, main_process
+from unimernet.common.registry import registry
+from unimernet.common.utils import is_url
+from unimernet.datasets.data_utils import reorg_datasets_by_split
+from unimernet.runners.runner_base import RunnerBase
+from torch.utils.data.dataset import ChainDataset
+
+
+@registry.register_runner("runner_iter")
+class RunnerIter(RunnerBase):
+ """
+    Run training based on the number of iterations. This is common when the
+    training dataset is large. Under the hood, the logic is similar to
+    epoch-based training: every #iters_per_inner_epoch iterations are treated
+    as an inner epoch.
+
+    In the iter-based runner, after every #iters_per_inner_epoch steps, we
+
+    1) do a validation epoch;
+    2) schedule the learning rate;
+    3) save a checkpoint.
+
+    We refer to every #iters_per_inner_epoch steps as an inner epoch.
+ """
+
+ def __init__(self, cfg, task, model, datasets, job_id):
+ super().__init__(cfg, task, model, datasets, job_id)
+
+ self.start_iters = 0
+
+ self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
+ assert self.max_iters > 0, "max_iters must be greater than 0."
+
+ self.iters_per_inner_epoch = int(
+ self.config.run_cfg.get("iters_per_inner_epoch", -1)
+ )
+ assert (
+ self.iters_per_inner_epoch > 0
+ ), "iters_per_inner_epoch must be greater than 0."
+
+ @property
+ def max_epoch(self):
+ return int(self.max_iters / self.iters_per_inner_epoch)
+
+ @property
+ def cur_epoch(self):
+ try:
+ return self.train_loader.epoch
+ except AttributeError:
+            # pipeline data (e.g. LAION) is streaming and has no concept of an epoch
+ return 0
+
+ def _progress(self, cur_iters):
+ return "{}_iters={}".format(self.cur_epoch, cur_iters)
+
+ def train(self):
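+        # Outer loop sketch: each pass below trains for iters_per_inner_epoch
+        # steps, evaluates on the valid splits, tracks the best aggregate metric,
+        # and saves the "latest" checkpoint before moving to the next inner epoch.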
+ start_time = time.time()
+ best_agg_metric = 0
+ best_iters = 0
+
+ self.log_config()
+
+ # resume from checkpoint if specified
+ if not self.evaluate_only and self.resume_ckpt_path is not None:
+ self._load_checkpoint(self.resume_ckpt_path)
+ cur_epoch = 0
+ for start_iters in range(
+ self.start_iters, self.max_iters, self.iters_per_inner_epoch
+ ):
+ end_iters = start_iters + self.iters_per_inner_epoch
+
+ # training phase
+ if not self.evaluate_only:
+ logging.info(
+ "Start training, max_iters={}, in total {} inner epochs.".format(
+ self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
+ )
+ )
+
+ train_stats = self.train_iters(self.cur_epoch, start_iters)
+ self.log_stats(split_name="train", stats=train_stats)
+
+ # evaluation phase
+ if len(self.valid_splits) > 0:
+ for split_name in self.valid_splits:
+ logging.info("Evaluating on {}.".format(split_name))
+
+ val_log = self.eval_epoch(
+ split_name=split_name, cur_epoch=self._progress(end_iters)
+ )
+ if val_log is not None:
+ if is_main_process():
+ assert (
+ "agg_metrics" in val_log
+ ), "No agg_metrics found in validation log."
+
+ agg_metrics = val_log["agg_metrics"]
+ if agg_metrics > best_agg_metric and split_name == "eval":
+ best_iters, best_agg_metric = end_iters, agg_metrics
+
+ self._save_checkpoint(end_iters, is_best=True)
+ val_log.update({"best_iters": best_iters})
+ self.log_stats(val_log, split_name)
+ # print evaluation metric
+ print(f"bleu:{val_log['bleu']:.6f}, edit_distance:{val_log['edit_distance']:.6f}, token_accuracy:{val_log['token_accuracy']:.6f} ")
+ print("="*80)
+
+ if self.evaluate_only:
+ break
+ if self.milestone and cur_epoch + 1 in self.milestone:
+ self._save_checkpoint(cur_epoch)
+ self._save_checkpoint(end_iters, latest=True)
+ dist.barrier()
+ cur_epoch += 1
+
+ # testing phase
+ self.evaluate(cur_epoch=self.cur_epoch)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logging.info("Training time {}".format(total_time_str))
+
+ def train_iters(self, epoch, start_iters):
+ # train by iterations
+ self.model.train()
+
+ return self.task.train_iters(
+ epoch=epoch,
+ start_iters=start_iters,
+ iters_per_inner_epoch=self.iters_per_inner_epoch,
+ model=self.model,
+ data_loader=self.train_loader,
+ optimizer=self.optimizer,
+ scaler=self.scaler,
+ lr_scheduler=self.lr_scheduler,
+ cuda_enabled=self.cuda_enabled,
+ log_freq=self.log_freq,
+ accum_grad_iters=self.accum_grad_iters,
+ )
+
+ @main_process
+ def _save_checkpoint(self, cur_iters, is_best=False, latest=False):
+        # only save the params that require gradients
+        assert not (is_best and latest), "You can't set 'is_best' and 'latest' at the same time."
+ unwrapped_model = self.unwrap_dist_model(self.model)
+ param_grad_dic = {
+ k: v.requires_grad for (k, v) in unwrapped_model.named_parameters()
+ }
+
+ state_dict = unwrapped_model.state_dict()
+ for k in list(state_dict.keys()):
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
+ del state_dict[k]
+
+ save_obj = {
+ "model": state_dict,
+ "optimizer": self.optimizer.state_dict(),
+ "config": self.config.to_dict(),
+ "scaler": self.scaler.state_dict() if self.scaler else None,
+ "iters": cur_iters,
+ }
+ if is_best:
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format("best"),
+ )
+ elif latest:
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format("latest"),
+ )
+ else:
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format(cur_iters),
+ )
+ logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
+ torch.save(save_obj, save_to)
+
+ def _load_checkpoint(self, url_or_filename):
+ """
+ Resume from a checkpoint.
+ """
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location=self.device)
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location=self.device)
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+ self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ if self.scaler and "scaler" in checkpoint:
+ self.scaler.load_state_dict(checkpoint["scaler"])
+
+ self.start_iters = checkpoint["iters"] + 1
+ logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+ @property
+ def dataloaders(self) -> dict:
+ """
+        A property that creates the dataloaders by split lazily, on first access.
+
+        If no train_dataset_ratio is provided, map-style datasets are concatenated
+        and wds.DataPipeline datasets are chained separately. The training set
+        becomes a tuple (ConcatDataset, ChainDataset); both are optional, but at
+        least one of them is required. The resulting ConcatDataset and ChainDataset
+        will be sampled evenly.
+
+        If train_dataset_ratio is provided, a MultiIterLoader is created to sample
+        each dataset according to its ratio during training.
+
+        Multiple datasets for validation and test are currently not supported.
+
+ Returns:
+ dict: {split_name: (tuples of) dataloader}
+ """
+ if self._dataloaders is None:
+            # reorganize datasets by split and concatenate/chain if necessary
+
+ self.datasets = reorg_datasets_by_split(self.datasets)
+ # to keep the same structure as return value of concat_datasets
+ self.datasets = {
+ k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
+ }
+
+ # print dataset statistics after concatenation/chaining
+ for split_name in self.datasets:
+ if isinstance(self.datasets[split_name], tuple) or isinstance(
+ self.datasets[split_name], list
+ ):
+ # mixed wds.DataPipeline and torch.utils.data.Dataset
+ num_records = sum(
+ [
+ len(d)
+ if not type(d) in [wds.DataPipeline, ChainDataset]
+ else 0
+ for d in self.datasets[split_name]
+ ]
+ )
+
+ else:
+ try:
+ # a single map-style dataset
+ num_records = len(self.datasets[split_name])
+ except TypeError:
+ # a single wds.DataPipeline or ChainDataset
+ num_records = -1
+ logging.info(
+ "Only a single wds.DataPipeline dataset, no __len__ attribute."
+ )
+
+ if num_records >= 0:
+ logging.info(
+ "Loaded {} records for {} split from the dataset.".format(
+ num_records, split_name
+ )
+ )
+
+ # create dataloaders
+ split_names = sorted(self.datasets.keys())
+
+ datasets = [self.datasets[split] for split in split_names]
+ is_trains = [split in self.train_splits for split in split_names]
+
+ batch_sizes = [
+ self.config.run_cfg.batch_size_train
+ if split == "train"
+ else self.config.run_cfg.batch_size_eval
+ for split in split_names
+ ]
+
+ collate_fns = []
+ for dataset in datasets:
+ if isinstance(dataset, tuple) or isinstance(dataset, list):
+ collate_fns.append([getattr(d, "collater", None) for d in dataset])
+ else:
+ collate_fns.append(getattr(dataset, "collater", None))
+
+ dataloaders = self.create_loaders(
+ datasets=datasets,
+ num_workers=self.config.run_cfg.num_workers,
+ batch_sizes=batch_sizes,
+ is_trains=is_trains,
+ collate_fns=collate_fns,
+ )
+
+ self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+ return self._dataloaders
diff --git a/unimernet/tasks/__init__.py b/unimernet/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff8198a152b559b5f59c0c9154cfcf7a9df4871
--- /dev/null
+++ b/unimernet/tasks/__init__.py
@@ -0,0 +1,26 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.common.registry import registry
+from unimernet.tasks.base_task import BaseTask
+from unimernet.tasks.unimernet_train import UniMERNet_Train
+
+
+def setup_task(cfg):
+ assert "task" in cfg.run_cfg, "Task name must be provided."
+
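+    # The task name comes from the run config, e.g. run_cfg.task == "unimernet_train",
+    # which resolves to the UniMERNet_Train class registered in
+    # unimernet/tasks/unimernet_train.py via @registry.register_task("unimernet_train").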
+ task_name = cfg.run_cfg.task
+ task = registry.get_task_class(task_name).setup_task(cfg=cfg)
+ assert task is not None, "Task {} not properly registered.".format(task_name)
+
+ return task
+
+
+__all__ = [
+ "BaseTask",
+ "UniMERNet_Train",
+]
diff --git a/unimernet/tasks/__pycache__/__init__.cpython-310.pyc b/unimernet/tasks/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f78e665e826e074c75be2fb52c87bb78f363bdfd
Binary files /dev/null and b/unimernet/tasks/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/tasks/__pycache__/base_task.cpython-310.pyc b/unimernet/tasks/__pycache__/base_task.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9de391d4bef30304f5b6fb95245fcc57e7ff500a
Binary files /dev/null and b/unimernet/tasks/__pycache__/base_task.cpython-310.pyc differ
diff --git a/unimernet/tasks/__pycache__/unimernet_train.cpython-310.pyc b/unimernet/tasks/__pycache__/unimernet_train.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cea4e5569eba74538f4f7da00133c3b74d3e3438
Binary files /dev/null and b/unimernet/tasks/__pycache__/unimernet_train.cpython-310.pyc differ
diff --git a/unimernet/tasks/base_task.py b/unimernet/tasks/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ea4e27298ac0fcd0640853f726e1748ad2ee87
--- /dev/null
+++ b/unimernet/tasks/base_task.py
@@ -0,0 +1,288 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from unimernet.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
+from unimernet.common.logger import MetricLogger, SmoothedValue
+from unimernet.common.registry import registry
+from unimernet.datasets.data_utils import prepare_sample
+
+
+class BaseTask:
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ self.inst_id_key = "instance_id"
+
+ @classmethod
+ def setup_task(cls, **kwargs):
+ return cls()
+
+ def build_model(self, cfg):
+ model_config = cfg.model_cfg
+
+ model_cls = registry.get_model_class(model_config.arch)
+ return model_cls.from_config(model_config)
+
+ def build_datasets(self, cfg):
+ """
+        Build a dictionary of datasets, keyed by split: 'train', 'valid', 'test'.
+        Download the datasets and annotations automatically if they do not exist.
+
+        Args:
+            cfg (common.config.Config): configuration containing the dataset configs.
+
+        Returns:
+            dict: Dictionary of torch.utils.data.Dataset objects by split.
+ """
+
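+        # Hedged example of the returned structure, assuming a single dataset
+        # config named "formula_rec_eval" with an "eval" split:
+        #
+        #     {"formula_rec_eval": {"eval": <torch.utils.data.Dataset>}}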
+ datasets = dict()
+
+ datasets_config = cfg.datasets_cfg
+
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
+
+ for name in datasets_config:
+ dataset_config = datasets_config[name]
+
+ builder = registry.get_builder_class(name)(dataset_config)
+ dataset = builder.build_datasets()
+
+ if "train" in dataset and "sample_ratio" in dataset_config:
+ dataset["train"].sample_ratio = float(dataset_config.sample_ratio)
+
+ datasets[name] = dataset
+
+ return datasets
+
+ def train_step(self, model, samples):
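+        # The model's forward pass is expected to return a dict; "loss" drives the
+        # backward pass, and the full dict is forwarded to the metric logger in
+        # _train_inner_loop for logging.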
+ loss_dict = model(samples)
+ loss = loss_dict["loss"]
+ return loss, loss_dict
+
+ def valid_step(self, model, samples):
+ raise NotImplementedError
+
+ def before_evaluation(self, model, dataset, **kwargs):
+ model.before_evaluation(dataset=dataset, task_type=type(self))
+
+ def after_evaluation(self, **kwargs):
+ pass
+
+ def inference_step(self):
+ raise NotImplementedError
+
+ def evaluation(self, model, data_loader, cuda_enabled=True):
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation"
+ # TODO make it configurable
+ print_freq = 10
+
+ results = []
+
+ for samples in metric_logger.log_every(data_loader, print_freq, header):
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+
+ eval_output = self.valid_step(model=model, samples=samples)
+ results.extend(eval_output)
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ return results
+
+ def train_epoch(
+ self,
+ epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ cuda_enabled=False,
+ log_freq=50,
+ accum_grad_iters=1,
+ ):
+ return self._train_inner_loop(
+ epoch=epoch,
+ iters_per_epoch=len(data_loader),
+ model=model,
+ data_loader=data_loader,
+ optimizer=optimizer,
+ scaler=scaler,
+ lr_scheduler=lr_scheduler,
+ log_freq=log_freq,
+ cuda_enabled=cuda_enabled,
+ accum_grad_iters=accum_grad_iters,
+ )
+
+ def train_iters(
+ self,
+ epoch,
+ start_iters,
+ iters_per_inner_epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ cuda_enabled=False,
+ log_freq=50,
+ accum_grad_iters=1,
+ ):
+ return self._train_inner_loop(
+ epoch=epoch,
+ start_iters=start_iters,
+ iters_per_epoch=iters_per_inner_epoch,
+ model=model,
+ data_loader=data_loader,
+ optimizer=optimizer,
+ scaler=scaler,
+ lr_scheduler=lr_scheduler,
+ log_freq=log_freq,
+ cuda_enabled=cuda_enabled,
+ accum_grad_iters=accum_grad_iters,
+ )
+
+ def _train_inner_loop(
+ self,
+ epoch,
+ iters_per_epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ start_iters=None,
+ log_freq=50,
+ cuda_enabled=False,
+ accum_grad_iters=1,
+ ):
+ """
+ An inner training loop compatible with both epoch-based and iter-based training.
+
+ When using epoch-based, training stops after one epoch; when using iter-based,
+ training stops after #iters_per_epoch iterations.
+ """
+ use_amp = scaler is not None
+
+ if not hasattr(data_loader, "__next__"):
+ # convert to iterator if not already
+ data_loader = iter(data_loader)
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
+
+ # if iter-based runner, schedule lr based on inner epoch.
+ logging.info(
+ "Start training epoch {}, {} iters per inner epoch.".format(
+ epoch, iters_per_epoch
+ )
+ )
+ header = "Train: data epoch: [{}]".format(epoch)
+ if start_iters is None:
+ # epoch-based runner
+ inner_epoch = epoch
+ else:
+ # In iter-based runner, we schedule the learning rate based on iterations.
+ inner_epoch = start_iters // iters_per_epoch
+ header = header + "; inner epoch [{}]".format(inner_epoch)
+
+ for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
+ # if using iter-based runner, we stop after iters_per_epoch iterations.
+ if i >= iters_per_epoch:
+ break
+
+ samples = next(data_loader)
+
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+ samples.update(
+ {
+ "epoch": inner_epoch,
+ "num_iters_per_epoch": iters_per_epoch,
+ "iters": i,
+ }
+ )
+
+ lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ loss, loss_dict = self.train_step(model=model, samples=samples)
+                loss /= accum_grad_iters  # TODO: ensure this scaling does not affect the loss_dict values used for logging
+
+ # after_train_step()
+ if use_amp:
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ # update gradients every accum_grad_iters iterations
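+            # For example (hypothetical values), with accum_grad_iters=4 the
+            # optimizer steps on every 4th iteration, giving an effective batch
+            # size 4x larger; the loss was divided by accum_grad_iters above so
+            # the accumulated gradient matches that of a single large batch.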
+
+ if (i + 1) % accum_grad_iters == 0:
+ if use_amp:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ optimizer.step()
+ optimizer.zero_grad()
+
+ metric_logger.update(**loss_dict)
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # after train_epoch()
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
+ return {
+ k: "{:.3f}".format(meter.global_avg)
+ for k, meter in metric_logger.meters.items()
+ }
+
+ @staticmethod
+ def save_result(result, result_dir, filename, remove_duplicate=""):
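+        # Each rank writes its own "<filename>_rank<k>.json"; after the barrier the
+        # main process concatenates them, optionally drops duplicates by the
+        # `remove_duplicate` key, and writes the merged "<filename>.json".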
+ import json
+
+ result_file = os.path.join(
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
+ )
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
+
+ json.dump(result, open(result_file, "w"))
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ if is_main_process():
+ logging.warning("rank %d starts merging results." % get_rank())
+ # combine results from all processes
+ result = []
+
+ for rank in range(get_world_size()):
+ result_file = os.path.join(
+ result_dir, "%s_rank%d.json" % (filename, rank)
+ )
+ res = json.load(open(result_file, "r"))
+ result += res
+
+ if remove_duplicate:
+ result_new = []
+ id_list = []
+ for res in result:
+ if res[remove_duplicate] not in id_list:
+ id_list.append(res[remove_duplicate])
+ result_new.append(res)
+ result = result_new
+
+ json.dump(result, open(final_result_file, "w"))
+ print("result file saved to %s" % final_result_file)
+
+ return final_result_file
diff --git a/unimernet/tasks/unimernet_train.py b/unimernet/tasks/unimernet_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2b1b5da6ded9197edcea87a6a4ffded817eb5d
--- /dev/null
+++ b/unimernet/tasks/unimernet_train.py
@@ -0,0 +1,167 @@
+import torch
+import evaluate
+import random
+
+from unimernet.common.registry import registry
+from unimernet.tasks.base_task import BaseTask
+from unimernet.common.dist_utils import main_process
+import os.path as osp
+import json
+import numpy as np
+from torchtext.data import metrics
+from rapidfuzz.distance import Levenshtein
+
+
+@registry.register_task("unimernet_train")
+class UniMERNet_Train(BaseTask):
+
+ def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
+ super(UniMERNet_Train, self).__init__()
+ self.temperature = temperature
+ self.do_sample = do_sample
+ self.top_p = top_p
+ self.evaluate = evaluate
+ self.agg_metric = agg_metric
+
+ self.report_metric = report_metric
+
+ @classmethod
+ def setup_task(cls, cfg):
+ run_cfg = cfg.run_cfg
+ generate_cfg = run_cfg.generate_cfg
+
+ temperature = generate_cfg.get('temperature', .2)
+ do_sample = generate_cfg.get("do_sample", False)
+ top_p = generate_cfg.get("top_p", 0.95)
+
+ evaluate = run_cfg.evaluate
+ report_metric = run_cfg.get("report_metric", True)
+ agg_metric = run_cfg.get("agg_metric", "edit_distance")
+
+ return cls(
+ temperature=temperature,
+ do_sample=do_sample,
+ top_p=top_p,
+ evaluate=evaluate,
+ report_metric=report_metric,
+ agg_metric=agg_metric,
+ )
+
+ def valid_step(self, model, samples):
+ results = []
+ image, text = samples["image"], samples["text_input"]
+ preds = model.generate(
+ samples,
+ temperature=self.temperature,
+ do_sample=self.do_sample,
+ top_p=self.top_p
+ )
+ pred_tokens = preds["pred_tokens"]
+ pred_strs = preds["pred_str"]
+ pred_ids = preds["pred_ids"] # [b, n-1]
+
+ truth_inputs = model.tokenizer.tokenize(text)
+ truth_ids = truth_inputs["input_ids"][:, 1:]
+ truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
+ truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])
+
+ ids = samples["id"]
+
+ for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(pred_tokens, pred_strs,
+ pred_ids, truth_tokens,
+ truth_strs, truth_ids, ids):
+ pred_id = pred_id.tolist()
+ truth_id = truth_id.tolist()
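+            # Pad the shorter sequence with pad tokens so both have equal length,
+            # then score token accuracy only over positions where at least one of
+            # the two sequences holds a real (non-pad) token.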
+ shape_diff = len(pred_id) - len(truth_id)
+ if shape_diff < 0:
+ pred_id = pred_id + [model.tokenizer.pad_token_id] * (-shape_diff)
+ else:
+ truth_id = truth_id + [model.tokenizer.pad_token_id] * shape_diff
+ pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)
+ mask = torch.logical_or(pred_id != model.tokenizer.pad_token_id, truth_id != model.tokenizer.pad_token_id)
+ tok_acc = (pred_id == truth_id)[mask].float().mean().item()
+
+ this_item = {
+ "pred_token": pred_token,
+ "pred_str": pred_str,
+ "truth_str": truth_str,
+ "truth_token": truth_token,
+ "token_acc": tok_acc,
+ "id": id_
+ }
+ results.append(this_item)
+ return results
+
+ def after_evaluation(self, val_result, split_name, epoch, **kwargs):
+ eval_result_file = self.save_result(
+ result=val_result,
+ result_dir=registry.get_path("result_dir"),
+ filename="{}_epoch{}".format(split_name, epoch),
+ remove_duplicate="id",
+ )
+
+ if self.report_metric:
+ metrics = self._report_metrics(
+ eval_result_file=eval_result_file, split_name=split_name
+ )
+ else:
+ metrics = {"agg_metrics": 0.0}
+
+ return metrics
+
+ @main_process
+ def _report_metrics(self, eval_result_file, split_name):
+
+ with open(eval_result_file) as f:
+ results = json.load(f)
+
+ edit_dists = []
+ all_pred_tokens = []
+ all_truth_tokens = []
+ all_pred_strs = []
+ all_truth_strs = []
+ token_accs = []
+ for result in results:
+ pred_token, pred_str, truth_token, truth_str, tok_acc = result["pred_token"], result["pred_str"], result[
+ "truth_token"], result["truth_str"], result["token_acc"]
+
+ if len(truth_str) > 0:
+ norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
+ edit_dists.append(norm_edit_dist)
+
+ all_pred_tokens.append(pred_token)
+ all_truth_tokens.append([truth_token])
+ all_pred_strs.append(pred_str)
+ all_truth_strs.append(truth_str)
+ token_accs.append(tok_acc)
+
+ # bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
+        bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=random.randint(1, int(1e8)))
+ bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
+ bleu_score = bleu_results['bleu']
+
+ edit_distance = np.mean(edit_dists)
+ token_accuracy = np.mean(token_accs)
+ eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}
+
+ log_stats = {split_name: {k: v for k, v in eval_ret.items()}}
+
+ with open(
+ osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
+ ) as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ coco_res = {k: v for k, v in eval_ret.items()}
+ # agg_metrics = sum([v for v in eval_ret.values()])
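+        # agg_metrics is the single scalar the runner compares against the best so
+        # far; each option is mapped to a higher-is-better value, e.g. an average
+        # edit distance of 0.12 yields agg_metrics == 88.0.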
+ if "edit" in self.agg_metric.lower(): # edit_distance
+ agg_metrics = (1 - edit_distance) * 100
+ elif "bleu" in self.agg_metric.lower(): # bleu_score
+ agg_metrics = bleu_score * 100
+ elif "token" in self.agg_metric.lower(): # token_accuracy
+ agg_metrics = token_accuracy * 100
+ else:
+ raise ValueError(f"Invalid metrics: '{self.agg_metric}'")
+
+ coco_res["agg_metrics"] = agg_metrics
+
+ return coco_res