diff --git a/app_lhm.py b/app_lhm.py index e929b8bd9e1066a8f58c73569136429d0c655111..24adb21e6a94a0e8ebb6d3c59e23c74191268184 100644 --- a/app_lhm.py +++ b/app_lhm.py @@ -22,24 +22,24 @@ import base64 import subprocess import os -def install_cuda_toolkit(): -# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run" -# # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run" -# CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL) -# subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE]) -# subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE]) -# subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"]) - - os.environ["CUDA_HOME"] = "/usr/local/cuda" - os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"]) - os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % ( - os.environ["CUDA_HOME"], - "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"], - ) - # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" - -install_cuda_toolkit() +# def install_cuda_toolkit(): +# # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run" +# # # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run" +# # CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL) +# # subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE]) +# # subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE]) +# # subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"]) + +# os.environ["CUDA_HOME"] = "/usr/local/cuda" +# os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"]) +# os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % ( +# os.environ["CUDA_HOME"], +# "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"], +# ) +# # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range +# os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" + +# install_cuda_toolkit() def launch_pretrained(): from huggingface_hub import snapshot_download, hf_hub_download @@ -54,7 +54,8 @@ def launch_env_not_compile_with_cuda(): os.system("pip install chumpy") os.system("pip uninstall -y basicsr") os.system("pip install git+https://github.com/hitsz-zuoqi/BasicSR/") - os.system("pip install git+https://github.com/hitsz-zuoqi/sam2/") + os.system("pip install -e ./third_party/sam2") + # os.system("pip install git+https://github.com/hitsz-zuoqi/sam2/") # os.system("pip install git+https://github.com/ashawkey/diff-gaussian-rasterization/") # os.system("pip install git+https://github.com/camenduru/simple-knn/") os.system("pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html") @@ -78,8 +79,7 @@ def launch_env_not_compile_with_cuda(): # os.system("mv pytorch3d /usr/local/lib/python3.10/site-packages/") # os.system("mv pytorch3d-0.7.8.dist-info /usr/local/lib/python3.10/site-packages/") -launch_pretrained() -launch_env_not_compile_with_cuda() + # launch_env_compile_with_cuda() def assert_input_image(input_image): @@ -268,5 +268,6 @@ def launch_gradio_app(): if __name__ == '__main__': - + launch_pretrained() + launch_env_not_compile_with_cuda() launch_gradio_app() diff --git a/sam2_configs/__init__.py b/sam2_configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sam2_configs/sam2.1_hiera_l.yaml b/sam2_configs/sam2.1_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57d9d08531b9d371f2cb074028824f718d58a775 --- /dev/null +++ b/sam2_configs/sam2.1_hiera_l.yaml @@ -0,0 +1,120 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False \ No newline at end of file diff --git a/third_party/sam2/.clang-format b/third_party/sam2/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..39b1b3d603ed0cf6b7f94c9c08067f148f35613f --- /dev/null +++ b/third_party/sam2/.clang-format @@ -0,0 +1,85 @@ +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never diff --git a/third_party/sam2/.github/workflows/check_fmt.yml b/third_party/sam2/.github/workflows/check_fmt.yml new file mode 100644 index 0000000000000000000000000000000000000000..0a29b884af2b5c0bdb71b607e7b8220e879755be --- /dev/null +++ b/third_party/sam2/.github/workflows/check_fmt.yml @@ -0,0 +1,17 @@ +name: SAM2/fmt +on: + pull_request: + branches: + - main +jobs: + ufmt_check: + runs-on: ubuntu-latest + steps: + - name: Check formatting + uses: omnilib/ufmt@action-v1 + with: + path: sam2 tools + version: "2.0.0b2" + python-version: "3.10" + black-version: "24.2.0" + usort-version: "1.0.2" diff --git a/third_party/sam2/.gitignore b/third_party/sam2/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..121d46aa5c1854ee2b2a5085ff2ce22f5a04043f --- /dev/null +++ b/third_party/sam2/.gitignore @@ -0,0 +1,11 @@ +.vscode/ +.DS_Store +__pycache__/ +*-checkpoint.ipynb +.venv +*.egg* +build/* +_C.* +outputs/* +checkpoints/*.pt +demo/backend/checkpoints/*.pt diff --git a/third_party/sam2/.watchmanconfig b/third_party/sam2/.watchmanconfig new file mode 100644 index 0000000000000000000000000000000000000000..9e26dfeeb6e641a33dae4961196235bdb965b21b --- /dev/null +++ b/third_party/sam2/.watchmanconfig @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/third_party/sam2/CODE_OF_CONDUCT.md b/third_party/sam2/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..08b500a221857ec3f451338e80b4a9ab1173a1af --- /dev/null +++ b/third_party/sam2/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/third_party/sam2/CONTRIBUTING.md b/third_party/sam2/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..ad15049f583e1bc9a418686493405875b98c7f0f --- /dev/null +++ b/third_party/sam2/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to segment-anything +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to segment-anything, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/INSTALL.md b/third_party/sam2/INSTALL.md new file mode 100644 index 0000000000000000000000000000000000000000..9480ba1bb52c171cfccc6a078c68abdb49125daa --- /dev/null +++ b/third_party/sam2/INSTALL.md @@ -0,0 +1,189 @@ +## Installation + +### Requirements + +- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this. + * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`. +- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command. +- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. + +Then, install SAM 2 from the root of this repository via +```bash +pip install -e ".[notebooks]" +``` + +Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows: +```bash +# skip the SAM 2 CUDA extension +SAM2_BUILD_CUDA=0 pip install -e ".[notebooks]" +``` +This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases. + +### Building the SAM 2 CUDA extension + +By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.) + +If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases. + +If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows +```bash +pip uninstall -y SAM-2 && \ +rm -f ./sam2/*.so && \ +SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[notebooks]" +``` + +Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`. + +Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime. + +### Common Installation Issues + +Click each issue for its solutions: + +
+ +I got `ImportError: cannot import name '_C' from 'sam2'` + +
+ +This is usually because you haven't run the `pip install -e ".[notebooks]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails. + +In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/sam2/issues/77. +
+ +
+ +I got `MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'` + +
+ +This is usually because you haven't run the `pip install -e .` step above, so `sam2` isn't in your Python's `sys.path`. Please run this installation step. In case it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via +```bash +export SAM2_REPO_ROOT=/path/to/sam2 # path to this repo +export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}" +``` +to manually add `sam2_configs` into your Python's `sys.path`. + +
+ +
+ +I got `RuntimeError: Error(s) in loading state_dict for SAM2Base` when loading the new SAM 2.1 checkpoints + +
+ +This is likely because you have installed a previous version of this repo, which doesn't have the new modules to support the SAM 2.1 checkpoints yet. Please try the following steps: + +1. pull the latest code from the `main` branch of this repo +2. run `pip uninstall -y SAM-2` to uninstall any previous installations +3. then install the latest repo again using `pip install -e ".[notebooks]"` + +In case the steps above still don't resolve the error, please try running in your Python environment the following +```python +from sam2.modeling import sam2_base + +print(sam2_base.__file__) +``` +and check whether the content in the printed local path of `sam2/modeling/sam2_base.py` matches the latest one in https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam2_base.py (e.g. whether your local file has `no_obj_embed_spatial`) to indentify if you're still using a previous installation. + +
+ +
+ +My installation failed with `CUDA_HOME environment variable is not set` + +
+ +This usually happens because the installation step cannot find the CUDA toolkits (that contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) or the version that matches the CUDA version for your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via +``` +export CUDA_HOME=/usr/local/cuda # change to your CUDA toolkit path +``` +and rerun the installation. + +Also, you should make sure +``` +python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)' +``` +print `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up. + +If you are still having problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command: +``` +pip install --no-build-isolation -e . +``` + +
+ +
+ +I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors) + +
+ +This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version. + +In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`. + +We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0. +
+ +
+ +I got `CUDA error: no kernel image is available for execution on the device` + +
+ +A possible cause could be that the CUDA kernel is somehow not compiled towards your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This could happen if the installation is done in an environment different from the runtime (e.g. in a slurm system). + +You can try pulling the latest code from the SAM 2 repo and running the following +``` +export TORCH_CUDA_ARCH_LIST=9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0` +``` +to manually specify the CUDA capability in the compilation target that matches your GPU. +
+ +
+ +I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors) + +
+ +This is probably because your machine doesn't have a GPU or a compatible PyTorch version for Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in PyTorch forum). You may be able to resolve this error by replacing the line +```python +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +``` +in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with +```python +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True +``` +to relax the attention kernel setting and use other kernels than Flash Attention. +
+ +
+ +I got `Error compiling objects for extension` + +
+ +You may see error log of: +> unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk. + +This is probably because your versions of CUDA and Visual Studio are incompatible. (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion in stackoverflow).
+You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in the [setup.py](https://github.com/facebookresearch/sam2/blob/main/setup.py).
+After adding the argument, `get_extension()` will look like this: +```python +def get_extensions(): + srcs = ["sam2/csrc/connected_components.cu"] + compile_args = { + "cxx": [], + "nvcc": [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + "-allow-unsupported-compiler" # Add this argument + ], + } + ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] + return ext_modules +``` +
diff --git a/third_party/sam2/LICENSE b/third_party/sam2/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/third_party/sam2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/third_party/sam2/LICENSE_cctorch b/third_party/sam2/LICENSE_cctorch new file mode 100644 index 0000000000000000000000000000000000000000..23da14a65aad4c5bac18061b80ae6040bb7d2c8c --- /dev/null +++ b/third_party/sam2/LICENSE_cctorch @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/third_party/sam2/MANIFEST.in b/third_party/sam2/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..794311fd9854453b134c828c0cb241a7cfdbfc65 --- /dev/null +++ b/third_party/sam2/MANIFEST.in @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +recursive-include sam2 *.yaml #include all config files diff --git a/third_party/sam2/README.md b/third_party/sam2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..85a7eb958bced5495ff990c2bcbe7d99662c660f --- /dev/null +++ b/third_party/sam2/README.md @@ -0,0 +1,224 @@ +# SAM 2: Segment Anything in Images and Videos + +**[AI at Meta, FAIR](https://ai.meta.com/research/)** + +[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/) + +[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)] + +![SAM 2 architecture](assets/model_diagram.png?raw=true) + +**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains. + +![SA-V dataset](assets/sa_v_dataset.jpg?raw=true) + +## Latest updates + +**12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking** + +- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor`, leading to a major speedup for VOS inference. +- We update the implementation of `SAM2VideoPredictor` to support independent per-object inference, allowing us to relax the assumption of prompting for multi-object tracking and adding new objects after tracking starts. +- See [`RELEASE_NOTES.md`](RELEASE_NOTES.md) for full details. + +**09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released** + +- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details. + * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below. +- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started. +- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details. + +## Installation + +SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.5.1` and `torchvision>=0.20.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using: + +```bash +git clone https://github.com/facebookresearch/sam2.git && cd sam2 + +pip install -e . +``` +If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. + +To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by: + +```bash +pip install -e ".[notebooks]" +``` + +Note: +1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.5.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.5.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`. +2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version. +3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases). + +Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions. + +## Getting Started + +### Download Checkpoints + +First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running: + +```bash +cd checkpoints && \ +./download_ckpts.sh && \ +cd .. +``` + +or individually from: + +- [sam2.1_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt) +- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt) +- [sam2.1_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt) +- [sam2.1_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt) + +(note that these are the improved checkpoints denoted as SAM 2.1; see [Model Description](#model-description) for details.) + +Then SAM 2 can be used in a few lines as follows for image and video prediction. + +### Image prediction + +SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting. + +```python +import torch +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +checkpoint = "./checkpoints/sam2.1_hiera_large.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" +predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint)) + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases. + +SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images. + +### Video prediction + +For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video. + +```python +import torch +from sam2.build_sam import build_sam2_video_predictor + +checkpoint = "./checkpoints/sam2.1_hiera_large.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" +predictor = build_sam2_video_predictor(model_cfg, checkpoint) + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... +``` + +Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos. + +## Load from 🤗 Hugging Face + +Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`). + +For image prediction: + +```python +import torch +from sam2.sam2_image_predictor import SAM2ImagePredictor + +predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +For video prediction: + +```python +import torch +from sam2.sam2_video_predictor import SAM2VideoPredictor + +predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... +``` + +## Model Description + +### SAM 2.1 checkpoints + +The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024. +| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | +| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | +| sam2.1_hiera_tiny
([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 | +| sam2.1_hiera_small
([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 | +| sam2.1_hiera_base_plus
([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 | +| sam2.1_hiera_large
([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 | + +### SAM 2 checkpoints + +The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows: + +| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | +| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | +| sam2_hiera_tiny
([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 | +| sam2_hiera_small
([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 | +| sam2_hiera_base_plus
([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 | +| sam2_hiera_large
([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 | + +Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config). +## Segment Anything Video Dataset + +See [sav_dataset/README.md](sav_dataset/README.md) for details. + +## Training SAM 2 + +You can train or fine-tune SAM 2 on custom datasets of images, videos, or both. Please check the training [README](training/README.md) on how to get started. + +## Web demo for SAM 2 + +We have released the frontend + backend code for the SAM 2 web demo (a locally deployable version similar to https://sam2.metademolab.com/demo). Please see the web demo [README](demo/README.md) for details. + +## License + +The SAM 2 model checkpoints, SAM 2 demo code (front-end and back-end), and SAM 2 training code are licensed under [Apache 2.0](./LICENSE), however the [Inter Font](https://github.com/rsms/inter?tab=OFL-1.1-1-ov-file) and [Noto Color Emoji](https://github.com/googlefonts/noto-emoji) used in the SAM 2 demo code are made available under the [SIL Open Font License, version 1.1](https://openfontlicense.org/open-font-license-official-text/). + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Contributors + +The SAM 2 project was made possible with the help of many contributors (alphabetical): + +Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang. + +Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions. + +## Citing SAM 2 + +If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry. + +```bibtex +@article{ravi2024sam2, + title={SAM 2: Segment Anything in Images and Videos}, + author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph}, + journal={arXiv preprint arXiv:2408.00714}, + url={https://arxiv.org/abs/2408.00714}, + year={2024} +} +``` diff --git a/third_party/sam2/RELEASE_NOTES.md b/third_party/sam2/RELEASE_NOTES.md new file mode 100644 index 0000000000000000000000000000000000000000..ee65ae7f4a51f1c7fce81204c5dc94467882d366 --- /dev/null +++ b/third_party/sam2/RELEASE_NOTES.md @@ -0,0 +1,27 @@ +## SAM 2 release notes + +### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking + +- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`). + * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS. + * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag. + * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model. + * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts. +- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`: + * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features. + * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage). + * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class. + +### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released + +- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details. + * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below. +- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started. +- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details. + +### 07/29/2024 -- SAM 2 is released + +- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos. + * SAM 2 code: https://github.com/facebookresearch/sam2 + * SAM 2 demo: https://sam2.metademolab.com/ + * SAM 2 paper: https://arxiv.org/abs/2408.00714 diff --git a/third_party/sam2/assets/model_diagram.png b/third_party/sam2/assets/model_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..61b8b7c08722f3cf433acaf8001013aa30ce62e9 Binary files /dev/null and b/third_party/sam2/assets/model_diagram.png differ diff --git a/third_party/sam2/assets/sa_v_dataset.jpg b/third_party/sam2/assets/sa_v_dataset.jpg new file mode 100644 index 0000000000000000000000000000000000000000..77af3b1a426e9e429b2a28ff7e5d7c13a40adfd0 Binary files /dev/null and b/third_party/sam2/assets/sa_v_dataset.jpg differ diff --git a/third_party/sam2/backend.Dockerfile b/third_party/sam2/backend.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..54a32967b0e053ae2e7d16c734928636ef46db7b --- /dev/null +++ b/third_party/sam2/backend.Dockerfile @@ -0,0 +1,64 @@ +ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime +ARG MODEL_SIZE=base_plus + +FROM ${BASE_IMAGE} + +# Gunicorn environment variables +ENV GUNICORN_WORKERS=1 +ENV GUNICORN_THREADS=2 +ENV GUNICORN_PORT=5000 + +# SAM 2 environment variables +ENV APP_ROOT=/opt/sam2 +ENV PYTHONUNBUFFERED=1 +ENV SAM2_BUILD_CUDA=0 +ENV MODEL_SIZE=${MODEL_SIZE} + +# Install system requirements +RUN apt-get update && apt-get install -y --no-install-recommends \ + ffmpeg \ + libavutil-dev \ + libavcodec-dev \ + libavformat-dev \ + libswscale-dev \ + pkg-config \ + build-essential \ + libffi-dev + +COPY setup.py . +COPY README.md . + +RUN pip install --upgrade pip setuptools +RUN pip install -e ".[interactive-demo]" + +# https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707 +RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg + +# Make app directory. This directory will host all files required for the +# backend and SAM 2 inference files. +RUN mkdir ${APP_ROOT} + +# Copy backend server files +COPY demo/backend/server ${APP_ROOT}/server + +# Copy SAM 2 inference files +COPY sam2 ${APP_ROOT}/server/sam2 + +# Download SAM 2.1 checkpoints +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt +ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt + +WORKDIR ${APP_ROOT}/server + +# https://pythonspeed.com/articles/gunicorn-in-docker/ +CMD gunicorn --worker-tmp-dir /dev/shm \ + --worker-class gthread app:app \ + --log-level info \ + --access-logfile /dev/stdout \ + --log-file /dev/stderr \ + --workers ${GUNICORN_WORKERS} \ + --threads ${GUNICORN_THREADS} \ + --bind 0.0.0.0:${GUNICORN_PORT} \ + --timeout 60 diff --git a/third_party/sam2/checkpoints/download_ckpts.sh b/third_party/sam2/checkpoints/download_ckpts.sh new file mode 100755 index 0000000000000000000000000000000000000000..eedee8eee153f17c6db3b92de5492fa0a11ec3b7 --- /dev/null +++ b/third_party/sam2/checkpoints/download_ckpts.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Use either wget or curl to download the checkpoints +if command -v wget &> /dev/null; then + CMD="wget" +elif command -v curl &> /dev/null; then + CMD="curl -L -O" +else + echo "Please install wget or curl to download the checkpoints." + exit 1 +fi + +# Define the URLs for SAM 2 checkpoints +# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824" +# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt" +# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt" +# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt" +# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt" + +# Download each of the four checkpoints using wget +# echo "Downloading sam2_hiera_tiny.pt checkpoint..." +# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; } + +# echo "Downloading sam2_hiera_small.pt checkpoint..." +# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; } + +# echo "Downloading sam2_hiera_base_plus.pt checkpoint..." +# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; } + +# echo "Downloading sam2_hiera_large.pt checkpoint..." +# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; } + +# Define the URLs for SAM 2.1 checkpoints +SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824" +sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt" +sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt" +sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt" +sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt" + +# SAM 2.1 checkpoints +echo "Downloading sam2.1_hiera_tiny.pt checkpoint..." +$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; } + +echo "Downloading sam2.1_hiera_small.pt checkpoint..." +$CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; } + +echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..." +$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; } + +echo "Downloading sam2.1_hiera_large.pt checkpoint..." +$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; } + +echo "All checkpoints are downloaded successfully." diff --git a/third_party/sam2/docker-compose.yaml b/third_party/sam2/docker-compose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a5395a585daa7d5a6e0e97d3a30b48f225fb2cf --- /dev/null +++ b/third_party/sam2/docker-compose.yaml @@ -0,0 +1,42 @@ +services: + frontend: + image: sam2/frontend + build: + context: ./demo/frontend + dockerfile: frontend.Dockerfile + ports: + - 7262:80 + + backend: + image: sam2/backend + build: + context: . + dockerfile: backend.Dockerfile + ports: + - 7263:5000 + volumes: + - ./demo/data/:/data/:rw + environment: + - SERVER_ENVIRONMENT=DEV + - GUNICORN_WORKERS=1 + # Inference API needs to have at least 2 threads to handle an incoming + # parallel cancel propagation request + - GUNICORN_THREADS=2 + - GUNICORN_PORT=5000 + - API_URL=http://localhost:7263 + - DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 + # # ffmpeg/video encode settings + - FFMPEG_NUM_THREADS=1 + - VIDEO_ENCODE_CODEC=libx264 + - VIDEO_ENCODE_CRF=23 + - VIDEO_ENCODE_FPS=24 + - VIDEO_ENCODE_MAX_WIDTH=1280 + - VIDEO_ENCODE_MAX_HEIGHT=720 + - VIDEO_ENCODE_VERBOSE=False + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] diff --git a/third_party/sam2/pyproject.toml b/third_party/sam2/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..f84317dbbfa6ba4f2d972cab2e2e0d0bdf07f003 --- /dev/null +++ b/third_party/sam2/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=61.0", + "torch>=2.5.1", + ] +build-backend = "setuptools.build_meta" diff --git a/third_party/sam2/sam2/__init__.py b/third_party/sam2/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d4f2e69b012ad1ca4aa418864c4aea6d362408 --- /dev/null +++ b/third_party/sam2/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("sam2_configs", version_base="1.2") diff --git a/third_party/sam2/sam2/automatic_mask_generator.py b/third_party/sam2/sam2/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..065e469e27c2d3af40d51d072031e828692c799b --- /dev/null +++ b/third_party/sam2/sam2/automatic_mask_generator.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import ( + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + MaskData, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + **kwargs, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2AutomaticMaskGenerator): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + points = torch.as_tensor( + points, dtype=torch.float32, device=self.predictor.device + ) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/third_party/sam2/sam2/benchmark.py b/third_party/sam2/sam2/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..6519534c8619e04b9a632859a5128ad2cee34c13 --- /dev/null +++ b/third_party/sam2/sam2/benchmark.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time + +import numpy as np +import torch +from tqdm import tqdm + +from sam2.build_sam import build_sam2_video_predictor + +# Only cuda supported +assert torch.cuda.is_available() +device = torch.device("cuda") + +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() +if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +# Config and checkpoint +sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" + +# Build video predictor with vos_optimized=True setting +predictor = build_sam2_video_predictor( + model_cfg, sam2_checkpoint, device=device, vos_optimized=True +) + + +# Initialize with video +video_dir = "notebooks/videos/bedroom" +# scan all the JPEG frame names in this directory +frame_names = [ + p + for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] +] +frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) +inference_state = predictor.init_state(video_path=video_dir) + + +# Number of runs, warmup etc +warm_up, runs = 5, 25 +verbose = True +num_frames = len(frame_names) +total, count = 0, 0 +torch.cuda.empty_cache() + +# We will select an object with a click. +# See video_predictor_example.ipynb for more detailed explanation +ann_frame_idx, ann_obj_id = 0, 1 +# Add a positive click at (x, y) = (210, 350) +# For labels, `1` means positive click +points = np.array([[210, 350]], dtype=np.float32) +labels = np.array([1], np.int32) + +_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, +) + +# Warmup and then average FPS over several runs +with torch.autocast("cuda", torch.bfloat16): + with torch.inference_mode(): + for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"): + start = time.time() + # Start tracking + for ( + out_frame_idx, + out_obj_ids, + out_mask_logits, + ) in predictor.propagate_in_video(inference_state): + pass + + end = time.time() + total += end - start + count += 1 + if i == warm_up - 1: + print("Warmup FPS: ", count * num_frames / total) + total = 0 + count = 0 + +print("FPS: ", count * num_frames / total) diff --git a/third_party/sam2/sam2/build_sam.py b/third_party/sam2/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3bef1e566d86c3ba0fd75f425530bc6505e9bf --- /dev/null +++ b/third_party/sam2/sam2/build_sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + +import sam2 + +# Check if the user is running Python from the parent directory of the sam2 repo +# (i.e. the directory where this repo is cloned into) -- this is not supported since +# it could shadow the sam2 package and cause issues. +if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): + # If the user has "sam2/sam2" in their path, they are likey importing the repo itself + # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). + # This typically happens because the user is running Python from the parent directory + # that contains the sam2 repo they cloned. + raise RuntimeError( + "You're likely running Python from the parent directory of the sam2 repository " + "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " + "This is not supported since the `sam2` Python package could be shadowed by the " + "repository name (the repository is also named `sam2` and contains the Python package " + "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " + "rather than its parent dir, or from your home directory) after installing SAM 2." + ) + + +HF_MODEL_ID_TO_FILENAMES = { + "facebook/sam2-hiera-tiny": ( + "configs/sam2/sam2_hiera_t.yaml", + "sam2_hiera_tiny.pt", + ), + "facebook/sam2-hiera-small": ( + "configs/sam2/sam2_hiera_s.yaml", + "sam2_hiera_small.pt", + ), + "facebook/sam2-hiera-base-plus": ( + "configs/sam2/sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ( + "configs/sam2/sam2_hiera_l.yaml", + "sam2_hiera_large.pt", + ), + "facebook/sam2.1-hiera-tiny": ( + "configs/sam2.1/sam2.1_hiera_t.yaml", + "sam2.1_hiera_tiny.pt", + ), + "facebook/sam2.1-hiera-small": ( + "configs/sam2.1/sam2.1_hiera_s.yaml", + "sam2.1_hiera_small.pt", + ), + "facebook/sam2.1-hiera-base-plus": ( + "configs/sam2.1/sam2.1_hiera_b+.yaml", + "sam2.1_hiera_base_plus.pt", + ), + "facebook/sam2.1-hiera-large": ( + "configs/sam2.1/sam2.1_hiera_l.yaml", + "sam2.1_hiera_large.pt", + ), +} + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + vos_optimized=False, + **kwargs, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if vos_optimized: + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS", + "++model.compile_image_encoder=True", # Let sam2_base handle this + ] + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def _hf_download(model_id): + from huggingface_hub import hf_hub_download + + config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return config_name, ckpt_path + + +def build_sam2_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +def build_sam2_video_predictor_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2_video_predictor( + config_file=config_name, ckpt_path=ckpt_path, **kwargs + ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7172f9b0b663aaaace97fed7e2a08db75150461 --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..23073ea7a95901be656b3c6d1a66ce8736ab7ad3 --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml @@ -0,0 +1,120 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd8d40465b18b3de39b0a565aca712306306c4ed --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -0,0 +1,119 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e762aec932f26436d13798f3feb3ec82c360a943 --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -0,0 +1,121 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/third_party/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/third_party/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b6faa79f47ee576faf007bffd23fb6649bd881d --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml @@ -0,0 +1,339 @@ +# @package _global_ + +scratch: + resolution: 1024 + train_batch_size: 1 + num_train_workers: 10 + num_frames: 8 + max_num_objects: 3 + base_lr: 5.0e-6 + vision_lr: 3.0e-06 + phases_per_epoch: 1 + num_epochs: 40 + +dataset: + # PATHS to Dataset + img_folder: null # PATH to MOSE JPEGImages folder + gt_folder: null # PATH to MOSE Annotations folder + file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training + multiplier: 2 + +# Video transforms +vos: + train_transforms: + - _target_: training.dataset.transforms.ComposeAPI + transforms: + - _target_: training.dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: training.dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: training.dataset.transforms.RandomResizeAPI + sizes: ${scratch.resolution} + square: true + consistent_transform: True + - _target_: training.dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: training.dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: training.dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: training.dataset.transforms.ToTensorAPI + - _target_: training.dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +trainer: + _target_: training.trainer.Trainer + mode: train_only + max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}} + accelerator: cuda + seed_value: 123 + + model: + _target_: training.model.sam2.SAM2Train + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + drop_path_rate: 0.1 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: ${scratch.resolution} + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # compile_image_encoder: False + + ####### Training specific params ####### + # box/point input and corrections + prob_to_use_pt_input_for_train: 0.5 + prob_to_use_pt_input_for_eval: 0.0 + prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points + prob_to_use_box_input_for_eval: 0.0 + prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors + num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame) + num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame + rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2 + add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame) + # maximum 2 initial conditioning frames + num_init_cond_frames_for_train: 2 + rand_init_cond_frames_for_train: True # random 1~2 + num_correction_pt_per_frame: 7 + use_act_ckpt_iterative_pt_sampling: false + + + + num_init_cond_frames_for_eval: 1 # only mask on the first frame + forward_backbone_per_frame_for_eval: True + + + data: + train: + _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset + phases_per_epoch: ${scratch.phases_per_epoch} + batch_sizes: + - ${scratch.train_batch_size} + + datasets: + - _target_: training.dataset.utils.RepeatFactorWrapper + dataset: + _target_: training.dataset.utils.ConcatDataset + datasets: + - _target_: training.dataset.vos_dataset.VOSDataset + transforms: ${vos.train_transforms} + training: true + video_dataset: + _target_: training.dataset.vos_raw_dataset.PNGRawDataset + img_folder: ${dataset.img_folder} + gt_folder: ${dataset.gt_folder} + file_list_txt: ${dataset.file_list_txt} + sampler: + _target_: training.dataset.vos_sampler.RandomUniformSampler + num_frames: ${scratch.num_frames} + max_num_objects: ${scratch.max_num_objects} + multiplier: ${dataset.multiplier} + shuffle: True + num_workers: ${scratch.num_train_workers} + pin_memory: True + drop_last: True + collate_fn: + _target_: training.utils.data_utils.collate_fn + _partial_: true + dict_key: all + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + optimizer: + _target_: torch.optim.AdamW + + gradient_clip: + _target_: training.optimizer.GradientClipper + max_norm: 0.1 + norm_type: 2 + + param_group_modifiers: + - _target_: training.optimizer.layer_decay_param_modifier + _partial_: True + layer_decay_value: 0.9 + apply_to: 'image_encoder.trunk' + overrides: + - pattern: '*pos_embed*' + value: 1.0 + + options: + lr: + - scheduler: + _target_: fvcore.common.param_scheduler.CosineParamScheduler + start_value: ${scratch.base_lr} + end_value: ${divide:${scratch.base_lr},10} + - scheduler: + _target_: fvcore.common.param_scheduler.CosineParamScheduler + start_value: ${scratch.vision_lr} + end_value: ${divide:${scratch.vision_lr},10} + param_names: + - 'image_encoder.*' + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.1 + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.0 + param_names: + - '*bias*' + module_cls_names: ['torch.nn.LayerNorm'] + + loss: + all: + _target_: training.loss_fns.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + + distributed: + backend: nccl + find_unused_parameters: True + + logging: + tensorboard_writer: + _target_: training.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + log_dir: ${launcher.experiment_log_dir}/logs + log_freq: 10 + + # initialize from a SAM 2 checkpoint + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + model_weight_initializer: + _partial_: True + _target_: training.utils.checkpoint_utils.load_state_dict_into_model + strict: True + ignore_unexpected_keys: null + ignore_missing_keys: null + + state_dict: + _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels + checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint + ckpt_state_dict_keys: ['model'] + +launcher: + num_nodes: 1 + gpus_per_node: 8 + experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} + +# SLURM args if running on a cluster +submitit: + partition: null + account: null + qos: null + cpus_per_task: 10 + use_cluster: false + timeout_hour: 24 + name: null + port_range: [10000, 65000] + diff --git a/third_party/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml b/third_party/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f435af02fc88e2d3b7bff06f8cf8013cc079c24 --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/third_party/sam2/sam2/configs/sam2/sam2_hiera_l.yaml b/third_party/sam2/sam2/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1092802b1d24be6fedf78939f45b0d021d4ec560 --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/third_party/sam2/sam2/configs/sam2/sam2_hiera_s.yaml b/third_party/sam2/sam2/configs/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..174e414f1467d80e94a34e9525dc373058f8caaa --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/third_party/sam2/sam2/configs/sam2/sam2_hiera_t.yaml b/third_party/sam2/sam2/configs/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..121447aabd5318fac20efc2bc00d7c406ca26f01 --- /dev/null +++ b/third_party/sam2/sam2/configs/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/third_party/sam2/sam2/csrc/connected_components.cu b/third_party/sam2/sam2/csrc/connected_components.cu new file mode 100644 index 0000000000000000000000000000000000000000..ced21eb32eaaadb818d441c1322b99d1bf068f45 --- /dev/null +++ b/third_party/sam2/sam2/csrc/connected_components.cu @@ -0,0 +1,289 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. + +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// adapted from https://github.com/zsef123/Connected_components_PyTorch +// with license found in the LICENSE_cctorch file in the root directory. +#include +#include +#include +#include +#include +#include + +// 2d +#define BLOCK_ROWS 16 +#define BLOCK_COLS 16 + +namespace cc2d { + +template +__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { + return (bitmap >> pos) & 1; +} + +__device__ int32_t find(const int32_t* s_buf, int32_t n) { + while (s_buf[n] != n) + n = s_buf[n]; + return n; +} + +__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { + const int32_t id = n; + while (s_buf[n] != n) { + n = s_buf[n]; + s_buf[id] = n; + } + return n; +} + +__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { + bool done; + do { + a = find(s_buf, a); + b = find(s_buf, b); + + if (a < b) { + int32_t old = atomicMin(s_buf + b, a); + done = (old == b); + b = old; + } else if (b < a) { + int32_t old = atomicMin(s_buf + a, b); + done = (old == a); + a = old; + } else + done = true; + + } while (!done); +} + +__global__ void +init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + label[idx] = idx; +} + +__global__ void +merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + uint32_t P = 0; + + if (img[idx]) + P |= 0x777; + if (row + 1 < H && img[idx + W]) + P |= 0x777 << 4; + if (col + 1 < W && img[idx + 1]) + P |= 0x777 << 1; + + if (col == 0) + P &= 0xEEEE; + if (col + 1 >= W) + P &= 0x3333; + else if (col + 2 >= W) + P &= 0x7777; + + if (row == 0) + P &= 0xFFF0; + if (row + 1 >= H) + P &= 0xFF; + + if (P > 0) { + // If need check about top-left pixel(if flag the first bit) and hit the + // top-left pixel + if (hasBit(P, 0) && img[idx - W - 1]) { + union_(label, idx, idx - 2 * W - 2); // top left block + } + + if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) + union_(label, idx, idx - 2 * W); // top bottom block + + if (hasBit(P, 3) && img[idx + 2 - W]) + union_(label, idx, idx - 2 * W + 2); // top right block + + if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) + union_(label, idx, idx - 2); // just left block + } +} + +__global__ void compression(int32_t* label, const int32_t W, const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + find_n_compress(label, idx); +} + +__global__ void final_labeling( + const uint8_t* img, + int32_t* label, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx] + 1; + + if (img[idx]) + label[idx] = y; + else + label[idx] = 0; + + if (col + 1 < W) { + if (img[idx + 1]) + label[idx + 1] = y; + else + label[idx + 1] = 0; + + if (row + 1 < H) { + if (img[idx + W + 1]) + label[idx + W + 1] = y; + else + label[idx + W + 1] = 0; + } + } + + if (row + 1 < H) { + if (img[idx + W]) + label[idx + W] = y; + else + label[idx + W] = 0; + } +} + +__global__ void init_counting( + const int32_t* label, + int32_t* count_init, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + atomicAdd(count_init + count_idx, 1); + } +} + +__global__ void final_counting( + const int32_t* label, + const int32_t* count_init, + int32_t* count_final, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + count_final[idx] = count_init[count_idx]; + } else { + count_final[idx] = 0; + } +} + +} // namespace cc2d + +std::vector get_connected_componnets( + const torch::Tensor& inputs) { + AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); + AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM( + inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); + + const uint32_t N = inputs.size(0); + const uint32_t C = inputs.size(1); + const uint32_t H = inputs.size(2); + const uint32_t W = inputs.size(3); + + AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM((H % 2) == 0, "height must be an even number"); + AT_ASSERTM((W % 2) == 0, "width must be an even number"); + + // label must be uint32_t + auto label_options = + torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); + torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); + + dim3 grid = dim3( + ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, + ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); + dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); + dim3 grid_count = + dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); + dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + for (int n = 0; n < N; n++) { + uint32_t offset = n * H * W; + + cc2d::init_labeling<<>>( + labels.data_ptr() + offset, W, H); + cc2d::merge<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + cc2d::compression<<>>( + labels.data_ptr() + offset, W, H); + cc2d::final_labeling<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + + // get the counting of each pixel + cc2d::init_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + W, + H); + cc2d::final_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + counts_final.data_ptr() + offset, + W, + H); + } + + // returned values are [labels, counts] + std::vector outputs; + outputs.push_back(labels); + outputs.push_back(counts_final); + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "get_connected_componnets", + &get_connected_componnets, + "get_connected_componnets"); +} diff --git a/third_party/sam2/sam2/modeling/__init__.py b/third_party/sam2/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/sam2/modeling/backbones/__init__.py b/third_party/sam2/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/sam2/modeling/backbones/hieradet.py b/third_party/sam2/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..19ac77b61d8e1345a301686d39ef2ab6e4b035fb --- /dev/null +++ b/third_party/sam2/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr + +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weights_path=None, + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/third_party/sam2/sam2/modeling/backbones/image_encoder.py b/third_party/sam2/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37e9266bc98596e97ca303118c910ed24f6cee2c --- /dev/null +++ b/third_party/sam2/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + self.d_model = d_model + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/third_party/sam2/sam2/modeling/backbones/utils.py b/third_party/sam2/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..930b1b7622e7b0e7270120dcafccc242ef0f4f28 --- /dev/null +++ b/third_party/sam2/sam2/modeling/backbones/utils.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.reshape( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :] + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/third_party/sam2/sam2/modeling/memory_attention.py b/third_party/sam2/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0b07f9d87e3d8194ca5e11fc20f01604d591a59d --- /dev/null +++ b/third_party/sam2/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from sam2.modeling.sam.transformer import RoPEAttention + +from sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/third_party/sam2/sam2/modeling/memory_encoder.py b/third_party/sam2/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f60202dfaba87232c3870fb2101b5322a119d985 --- /dev/null +++ b/third_party/sam2/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/third_party/sam2/sam2/modeling/position_encoding.py b/third_party/sam2/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..2241d4cf1a4495b4c67dc35cbed1c606357b9b7a --- /dev/null +++ b/third_party/sam2/sam2/modeling/position_encoding.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + # Following settings only relevant + # for warmping up cache for compilation + warmup_cache: bool = True, + image_size: int = 1024, + strides: Tuple[int] = (4, 8, 16, 32), + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + if warmup_cache and torch.cuda.is_available(): + # Warmup cache for cuda, to help with compilation + device = torch.device("cuda") + for stride in strides: + cache_key = (image_size // stride, image_size // stride) + self._pe(1, device, *cache_key) + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def _pe(self, B, device, *cache_key): + H, W = cache_key + if cache_key in self.cache: + return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) + + y_embed = ( + torch.arange(1, H + 1, dtype=torch.float32, device=device) + .view(1, -1, 1) + .repeat(B, 1, W) + ) + x_embed = ( + torch.arange(1, W + 1, dtype=torch.float32, device=device) + .view(1, 1, -1) + .repeat(B, H, 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + B = x.shape[0] + cache_key = (x.shape[-2], x.shape[-1]) + return self._pe(B, x.device, *cache_key) + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/third_party/sam2/sam2/modeling/sam/__init__.py b/third_party/sam2/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/sam2/modeling/sam/mask_decoder.py b/third_party/sam2/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9bebc0366b2703ffcb80a44bfd19cce8339b4fed --- /dev/null +++ b/third_party/sam2/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/third_party/sam2/sam2/modeling/sam/prompt_encoder.py b/third_party/sam2/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c57876264b51f8c5236867359350e32d590efcb5 --- /dev/null +++ b/third_party/sam2/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.position_encoding import PositionEmbeddingRandom + +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + + point_embedding = torch.where( + (labels == -1).unsqueeze(-1), + torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 0).unsqueeze(-1), + point_embedding + self.point_embeddings[0].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 1).unsqueeze(-1), + point_embedding + self.point_embeddings[1].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 2).unsqueeze(-1), + point_embedding + self.point_embeddings[2].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 3).unsqueeze(-1), + point_embedding + self.point_embeddings[3].weight, + point_embedding, + ) + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings diff --git a/third_party/sam2/sam2/modeling/sam/transformer.py b/third_party/sam2/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f9fe9a3fbc5cce4f1abe8ee0ae3a8602bbe2ff1b --- /dev/null +++ b/third_party/sam2/sam2/modeling/sam/transformer.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from sam2.modeling.sam2_utils import MLP + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = ( + freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis + ) + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/third_party/sam2/sam2/modeling/sam2_base.py b/third_party/sam2/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f4e515b0d161942bf2bb64560056b3efbe6dac --- /dev/null +++ b/third_party/sam2/sam2/modeling/sam2_base.py @@ -0,0 +1,909 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam.mask_decoder import MaskDecoder +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer +from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = image_encoder.neck.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. + stride = 1 if self.training else self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list).to( + device=device, non_blocking=True + ) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/third_party/sam2/sam2/modeling/sam2_utils.py b/third_party/sam2/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e16caae3a9a49e451b2d03d1ee60c47f8e9ed23c --- /dev/null +++ b/third_party/sam2/sam2/modeling/sam2_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.utils.misc import mask_to_box + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/third_party/sam2/sam2/sam2_hiera_b+.yaml b/third_party/sam2/sam2/sam2_hiera_b+.yaml new file mode 120000 index 0000000000000000000000000000000000000000..998d9c98c9ff4e8ddd55deff72aa0d9067977418 --- /dev/null +++ b/third_party/sam2/sam2/sam2_hiera_b+.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_b+.yaml \ No newline at end of file diff --git a/third_party/sam2/sam2/sam2_hiera_l.yaml b/third_party/sam2/sam2/sam2_hiera_l.yaml new file mode 120000 index 0000000000000000000000000000000000000000..c0e7e58e1951d5c55a3a3ebe6b803dd814cf9d86 --- /dev/null +++ b/third_party/sam2/sam2/sam2_hiera_l.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_l.yaml \ No newline at end of file diff --git a/third_party/sam2/sam2/sam2_hiera_s.yaml b/third_party/sam2/sam2/sam2_hiera_s.yaml new file mode 120000 index 0000000000000000000000000000000000000000..41896a26beb2aa831d18b0bf3c349ed43deeef68 --- /dev/null +++ b/third_party/sam2/sam2/sam2_hiera_s.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_s.yaml \ No newline at end of file diff --git a/third_party/sam2/sam2/sam2_hiera_t.yaml b/third_party/sam2/sam2/sam2_hiera_t.yaml new file mode 120000 index 0000000000000000000000000000000000000000..71ff3abbb1e11f8b82100a0a1d63cb267eefe52a --- /dev/null +++ b/third_party/sam2/sam2/sam2_hiera_t.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_t.yaml \ No newline at end of file diff --git a/third_party/sam2/sam2/sam2_image_predictor.py b/third_party/sam2/sam2/sam2_image_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..41ce53af5924504c07216df52b2d2eefaeec7ae9 --- /dev/null +++ b/third_party/sam2/sam2/sam2_image_predictor.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base + +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + **kwargs, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tuple of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/third_party/sam2/sam2/sam2_video_predictor.py b/third_party/sam2/sam2/sam2_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7e1a01c4d6e89db0453ce982ea8a31b16651c8 --- /dev/null +++ b/third_party/sam2/sam2/sam2_video_predictor.py @@ -0,0 +1,1223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["frames_tracked_per_obj"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # We always allow adding new objects (including after tracking starts). + allow_new_object = True + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["frames_tracked_per_obj"][obj_idx] = {} + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] + is_init_cond_frame = frame_idx not in obj_frames_tracked + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = obj_frames_tracked[frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] + is_init_cond_frame = frame_idx not in obj_frames_tracked + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = obj_frames_tracked[frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + } + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + + return consolidated_out + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Check and make sure that every object has received input points or masks. + batch_size = self._get_obj_num(inference_state) + if batch_size == 0: + raise RuntimeError( + "No input points or masks are provided for any object; please add inputs first." + ) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = ( + "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + ) + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + for frame_idx, out in obj_temp_output_dict[storage_key].items(): + # Run memory encoder on the temporary outputs (if the memory feature is missing) + if out["maskmem_features"] is None: + high_res_masks = torch.nn.functional.interpolate( + out["pred_masks"].to(inference_state["device"]), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + high_res_masks=high_res_masks, + object_score_logits=out["object_score_logits"], + # these frames are what the user interacted with + is_mask_from_pts=True, + ) + out["maskmem_features"] = maskmem_features + out["maskmem_pos_enc"] = maskmem_pos_enc + + obj_output_dict[storage_key][frame_idx] = out + if self.clear_non_cond_mem_around_input: + # clear non-conditioning memory of the surrounding frames + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx + ) + + # clear temporary outputs in `temp_output_dict_per_obj` + obj_temp_output_dict[storage_key].clear() + + # check and make sure that every object has received input points or masks + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + if len(obj_output_dict["cond_frame_outputs"]) == 0: + obj_id = self._obj_idx_to_id(inference_state, obj_idx) + raise RuntimeError( + f"No input points or masks are provided for object id {obj_id}; please add inputs first." + ) + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min( + t + for obj_output_dict in inference_state["output_dict_per_obj"].values() + for t in obj_output_dict["cond_frame_outputs"] + ) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + pred_masks_per_obj = [None] * batch_size + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in obj_output_dict["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = obj_output_dict[storage_key][frame_idx] + device = inference_state["device"] + pred_masks = current_out["pred_masks"].to(device, non_blocking=True) + if self.clear_non_cond_mem_around_input: + # clear non-conditioning memory of the surrounding frames + self._clear_obj_non_cond_mem_around_input( + inference_state, frame_idx, obj_idx + ) + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + obj_output_dict[storage_key][frame_idx] = current_out + + inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { + "reverse": reverse + } + pred_masks_per_obj[obj_idx] = pred_masks + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + if len(pred_masks_per_obj) > 1: + all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) + else: + all_pred_masks = pred_masks_per_obj[0] + _, video_res_masks = self._get_orig_video_res_output( + inference_state, all_pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def clear_all_prompts_in_frame( + self, inference_state, frame_idx, obj_id, need_output=True + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + obj_output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + inference_state["frames_tracked_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["frames_tracked_per_obj"].values(): + v.clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object(self, inference_state, obj_id, strict=False, need_output=True): + """ + Remove an object id from the tracking state. If strict is True, we check whether + the object id actually exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + updated_frames = [] + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"], updated_frames + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.reset_state(inference_state) + return inference_state["obj_ids"], updated_frames + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_prompts_in_frame( + inference_state, frame_idx, obj_id, need_output=False + ) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + _map_keys(inference_state["frames_tracked_per_obj"]) + + # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + batch_size = self._get_obj_num(inference_state) + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + + +class SAM2VideoPredictorVOS(SAM2VideoPredictor): + """Optimized for the VOS setting""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._compile_all_components() + + def _compile_all_components(self): + print("Compiling all components for VOS setting. First time may be very slow.") + self.memory_encoder.forward = torch.compile( + self.memory_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + self.memory_attention.forward = torch.compile( + self.memory_attention.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, # Num. of memories varies + ) + + self.sam_prompt_encoder.forward = torch.compile( + self.sam_prompt_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + self.sam_mask_decoder.forward = torch.compile( + self.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + def forward_image(self, img_batch: torch.Tensor): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the backbone features and pos encoding to enable compilation. + """ + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + # Clone to help torch.compile + for i in range(len(backbone_out["backbone_fpn"])): + backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone() + backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][ + i + ].clone() + return backbone_out + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the outputs of prompt_encoder and mask_decoder to enable compilation. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + # Clone image_pe and the outputs of sam_prompt_encoder + # to enable compilation + sparse_embeddings = sparse_embeddings.clone() + dense_embeddings = dense_embeddings.clone() + image_pe = self.sam_prompt_encoder.get_dense_pe().clone() + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + # Clone the output of sam_mask_decoder + # to enable compilation + low_res_multimasks = low_res_multimasks.clone() + ious = ious.clone() + sam_output_tokens = sam_output_tokens.clone() + object_score_logits = object_score_logits.clone() + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the memories and their pos enc to enable compilation. + """ + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + # Clone the feats and pos_enc to enable compilation + maskmem_features = maskmem_out["vision_features"].clone() + maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + + return maskmem_features, maskmem_pos_enc diff --git a/third_party/sam2/sam2/sam2_video_predictor_legacy.py b/third_party/sam2/sam2/sam2_video_predictor_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..c7e01ccf972491904b013526333826b337354db1 --- /dev/null +++ b/third_party/sam2/sam2/sam2_video_predictor_legacy.py @@ -0,0 +1,1172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ + "object_score_logits" + ] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + object_score_logits=consolidated_out["object_score_logits"], + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + "object_score_logits": current_out["object_score_logits"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def clear_all_prompts_in_frame( + self, inference_state, frame_idx, obj_id, need_output=True + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = self._get_obj_num(inference_state) + frame_has_input = False + for obj_idx2 in range(batch_size): + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object(self, inference_state, obj_id, strict=False, need_output=True): + """ + Remove an object id from the tracking state. If strict is True, we check whether + the object id actually exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + updated_frames = [] + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"], updated_frames + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.reset_state(inference_state) + return inference_state["obj_ids"], updated_frames + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_prompts_in_frame( + inference_state, frame_idx, obj_id, need_output=False + ) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. + def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] + out["maskmem_pos_enc"] = [ + x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] + ] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) + out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] + out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] + out["object_score_logits"] = out["object_score_logits"][ + remain_old_obj_inds + ] + # also update the per-object slices + self._add_output_per_object( + inference_state, frame_idx, out, storage_key + ) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/third_party/sam2/sam2/utils/__init__.py b/third_party/sam2/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/sam2/utils/amg.py b/third_party/sam2/sam2/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e --- /dev/null +++ b/third_party/sam2/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/third_party/sam2/sam2/utils/misc.py b/third_party/sam2/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b65ee825732ff85137805be650edd4cbe8e6f6d4 --- /dev/null +++ b/third_party/sam2/sam2/utils/misc.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/third_party/sam2/sam2/utils/transforms.py b/third_party/sam2/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..cc17bebfab104b659c5469e8434cf357ae7e24b6 --- /dev/null +++ b/third_party/sam2/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/third_party/sam2/setup.py b/third_party/sam2/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..78a634cddb19615c45601681ffbcd1f29af66f47 --- /dev/null +++ b/third_party/sam2/setup.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os + +from setuptools import find_packages, setup + +# Package metadata +NAME = "SAM-2" +VERSION = "1.0" +DESCRIPTION = "SAM 2: Segment Anything in Images and Videos" +URL = "https://github.com/facebookresearch/sam2" +AUTHOR = "Meta AI" +AUTHOR_EMAIL = "segment-anything@meta.com" +LICENSE = "Apache 2.0" + +# Read the contents of README file +with open("README.md", "r", encoding="utf-8") as f: + LONG_DESCRIPTION = f.read() + +# Required dependencies +REQUIRED_PACKAGES = [ + "torch>=2.5.1", + "torchvision>=0.20.1", + "numpy>=1.24.4", + "tqdm>=4.66.1", + "hydra-core>=1.3.2", + "iopath>=0.1.10", + "pillow>=9.4.0", +] + +EXTRA_PACKAGES = { + "notebooks": [ + "matplotlib>=3.9.1", + "jupyter>=1.0.0", + "opencv-python>=4.7.0", + "eva-decord>=0.6.1", + ], + "interactive-demo": [ + "Flask>=3.0.3", + "Flask-Cors>=5.0.0", + "av>=13.0.0", + "dataclasses-json>=0.6.7", + "eva-decord>=0.6.1", + "gunicorn>=23.0.0", + "imagesize>=1.4.1", + "pycocotools>=2.0.8", + "strawberry-graphql>=0.243.0", + ], + "dev": [ + "black==24.2.0", + "usort==1.0.2", + "ufmt==2.0.0b2", + "fvcore>=0.1.5.post20221221", + "pandas>=2.2.2", + "scikit-image>=0.24.0", + "tensorboard>=2.17.0", + "pycocotools>=2.0.8", + "tensordict>=0.6.0", + "opencv-python>=4.7.0", + "submitit>=1.5.1", + ], +} + +# By default, we also build the SAM 2 CUDA extension. +# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. +BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" +# By default, we allow SAM 2 installation to proceed even with build errors. +# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. +BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" + +# Catch and skip errors during extension building and print a warning message +# (note that this message only shows up under verbose build mode +# "pip install -v -e ." or "python setup.py build_ext -v") +CUDA_ERROR_MSG = ( + "{}\n\n" + "Failed to build the SAM 2 CUDA extension due to the error above. " + "You can still use SAM 2 and it's OK to ignore the error above, although some " + "post-processing functionality may be limited (which doesn't affect the results in most cases; " + "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n" +) + + +def get_extensions(): + if not BUILD_CUDA: + return [] + + try: + from torch.utils.cpp_extension import CUDAExtension + + srcs = ["sam2/csrc/connected_components.cu"] + compile_args = { + "cxx": [], + "nvcc": [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + } + ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] + except Exception as e: + if BUILD_ALLOW_ERRORS: + print(CUDA_ERROR_MSG.format(e)) + ext_modules = [] + else: + raise e + + return ext_modules + + +try: + from torch.utils.cpp_extension import BuildExtension + + class BuildExtensionIgnoreErrors(BuildExtension): + + def finalize_options(self): + try: + super().finalize_options() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + + def build_extensions(self): + try: + super().build_extensions() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + + def get_ext_filename(self, ext_name): + try: + return super().get_ext_filename(ext_name) + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + return "_C.so" + + cmdclass = { + "build_ext": ( + BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) + if BUILD_ALLOW_ERRORS + else BuildExtension.with_options(no_python_abi_suffix=True) + ) + } +except Exception as e: + cmdclass = {} + if BUILD_ALLOW_ERRORS: + print(CUDA_ERROR_MSG.format(e)) + else: + raise e + + +# Setup configuration +setup( + name=NAME, + version=VERSION, + description=DESCRIPTION, + long_description=LONG_DESCRIPTION, + long_description_content_type="text/markdown", + url=URL, + author=AUTHOR, + author_email=AUTHOR_EMAIL, + license=LICENSE, + packages=find_packages(exclude="notebooks"), + include_package_data=True, + install_requires=REQUIRED_PACKAGES, + extras_require=EXTRA_PACKAGES, + python_requires=">=3.10.0", + ext_modules=get_extensions(), + cmdclass=cmdclass, +) diff --git a/third_party/sam2/tools/README.md b/third_party/sam2/tools/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1dd0e8a754f4bf27ee321084076f3ebdb2285450 --- /dev/null +++ b/third_party/sam2/tools/README.md @@ -0,0 +1,36 @@ +## SAM 2 toolkits + +This directory provides toolkits for additional SAM 2 use cases. + +### Semi-supervised VOS inference + +The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset. + +After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`. +```bash +python ./tools/vos_inference.py \ + --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ + --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ + --base_video_dir /path-to-davis-2017/JPEGImages/480p \ + --input_mask_dir /path-to-davis-2017/Annotations/480p \ + --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \ + --output_mask_dir ./outputs/davis_2017_pred_pngs +``` +(replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset) + +To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag. +```bash +python ./tools/vos_inference.py \ + --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ + --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ + --base_video_dir /path-to-sav-val/JPEGImages_24fps \ + --input_mask_dir /path-to-sav-val/Annotations_6fps \ + --video_list_file /path-to-sav-val/sav_val.txt \ + --per_obj_png_file \ + --output_mask_dir ./outputs/sav_val_pred_pngs +``` +(replace `/path-to-sav-val` with the path to SA-V val) + +Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above. + +Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**. diff --git a/third_party/sam2/tools/vos_inference.py b/third_party/sam2/tools/vos_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3e8c6740541f342cfbbe0fa8ad80e47caf4ac9 --- /dev/null +++ b/third_party/sam2/tools/vos_inference.py @@ -0,0 +1,507 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +from collections import defaultdict + +import numpy as np +import torch +from PIL import Image +from sam2.build_sam import build_sam2_video_predictor + + +# the PNG palette for DAVIS 2017 dataset +DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0" + + +def load_ann_png(path): + """Load a PNG file as a mask and its palette.""" + mask = Image.open(path) + palette = mask.getpalette() + mask = np.array(mask).astype(np.uint8) + return mask, palette + + +def save_ann_png(path, mask, palette): + """Save a mask as a PNG file with the given palette.""" + assert mask.dtype == np.uint8 + assert mask.ndim == 2 + output_mask = Image.fromarray(mask) + output_mask.putpalette(palette) + output_mask.save(path) + + +def get_per_obj_mask(mask): + """Split a mask into per-object masks.""" + object_ids = np.unique(mask) + object_ids = object_ids[object_ids > 0].tolist() + per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids} + return per_obj_mask + + +def put_per_obj_mask(per_obj_mask, height, width): + """Combine per-object masks into a single mask.""" + mask = np.zeros((height, width), dtype=np.uint8) + object_ids = sorted(per_obj_mask)[::-1] + for object_id in object_ids: + object_mask = per_obj_mask[object_id] + object_mask = object_mask.reshape(height, width) + mask[object_mask] = object_id + return mask + + +def load_masks_from_dir( + input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False +): + """Load masks from a directory as a dict of per-object masks.""" + if not per_obj_png_file: + input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png") + if allow_missing and not os.path.exists(input_mask_path): + return {}, None + input_mask, input_palette = load_ann_png(input_mask_path) + per_obj_input_mask = get_per_obj_mask(input_mask) + else: + per_obj_input_mask = {} + input_palette = None + # each object is a directory in "{object_id:%03d}" format + for object_name in os.listdir(os.path.join(input_mask_dir, video_name)): + object_id = int(object_name) + input_mask_path = os.path.join( + input_mask_dir, video_name, object_name, f"{frame_name}.png" + ) + if allow_missing and not os.path.exists(input_mask_path): + continue + input_mask, input_palette = load_ann_png(input_mask_path) + per_obj_input_mask[object_id] = input_mask > 0 + + return per_obj_input_mask, input_palette + + +def save_masks_to_dir( + output_mask_dir, + video_name, + frame_name, + per_obj_output_mask, + height, + width, + per_obj_png_file, + output_palette, +): + """Save masks to a directory as PNG files.""" + os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True) + if not per_obj_png_file: + output_mask = put_per_obj_mask(per_obj_output_mask, height, width) + output_mask_path = os.path.join( + output_mask_dir, video_name, f"{frame_name}.png" + ) + save_ann_png(output_mask_path, output_mask, output_palette) + else: + for object_id, object_mask in per_obj_output_mask.items(): + object_name = f"{object_id:03d}" + os.makedirs( + os.path.join(output_mask_dir, video_name, object_name), + exist_ok=True, + ) + output_mask = object_mask.reshape(height, width).astype(np.uint8) + output_mask_path = os.path.join( + output_mask_dir, video_name, object_name, f"{frame_name}.png" + ) + save_ann_png(output_mask_path, output_mask, output_palette) + + +@torch.inference_mode() +@torch.autocast(device_type="cuda", dtype=torch.bfloat16) +def vos_inference( + predictor, + base_video_dir, + input_mask_dir, + output_mask_dir, + video_name, + score_thresh=0.0, + use_all_masks=False, + per_obj_png_file=False, +): + """Run VOS inference on a single video with the given predictor.""" + # load the video frames and initialize the inference state on this video + video_dir = os.path.join(base_video_dir, video_name) + frame_names = [ + os.path.splitext(p)[0] + for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + inference_state = predictor.init_state( + video_path=video_dir, async_loading_frames=False + ) + height = inference_state["video_height"] + width = inference_state["video_width"] + input_palette = None + + # fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks) + if not use_all_masks: + # use only the first video's ground-truth mask as the input mask + input_frame_inds = [0] + else: + # use all mask files available in the input_mask_dir as the input masks + if not per_obj_png_file: + input_frame_inds = [ + idx + for idx, name in enumerate(frame_names) + if os.path.exists( + os.path.join(input_mask_dir, video_name, f"{name}.png") + ) + ] + else: + input_frame_inds = [ + idx + for object_name in os.listdir(os.path.join(input_mask_dir, video_name)) + for idx, name in enumerate(frame_names) + if os.path.exists( + os.path.join(input_mask_dir, video_name, object_name, f"{name}.png") + ) + ] + # check and make sure we got at least one input frame + if len(input_frame_inds) == 0: + raise RuntimeError( + f"In {video_name=}, got no input masks in {input_mask_dir=}. " + "Please make sure the input masks are available in the correct format." + ) + input_frame_inds = sorted(set(input_frame_inds)) + + # add those input masks to SAM 2 inference state before propagation + object_ids_set = None + for input_frame_idx in input_frame_inds: + try: + per_obj_input_mask, input_palette = load_masks_from_dir( + input_mask_dir=input_mask_dir, + video_name=video_name, + frame_name=frame_names[input_frame_idx], + per_obj_png_file=per_obj_png_file, + ) + except FileNotFoundError as e: + raise RuntimeError( + f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. " + "Please add the `--track_object_appearing_later_in_video` flag " + "for VOS datasets that don't have all objects to track appearing " + "in the first frame (such as LVOS or YouTube-VOS)." + ) from e + # get the list of object ids to track from the first input frame + if object_ids_set is None: + object_ids_set = set(per_obj_input_mask) + for object_id, object_mask in per_obj_input_mask.items(): + # check and make sure no new object ids appear only in later frames + if object_id not in object_ids_set: + raise RuntimeError( + f"In {video_name=}, got a new {object_id=} appearing only in a " + f"later {input_frame_idx=} (but not appearing in the first frame). " + "Please add the `--track_object_appearing_later_in_video` flag " + "for VOS datasets that don't have all objects to track appearing " + "in the first frame (such as LVOS or YouTube-VOS)." + ) + predictor.add_new_mask( + inference_state=inference_state, + frame_idx=input_frame_idx, + obj_id=object_id, + mask=object_mask, + ) + + # check and make sure we have at least one object to track + if object_ids_set is None or len(object_ids_set) == 0: + raise RuntimeError( + f"In {video_name=}, got no object ids on {input_frame_inds=}. " + "Please add the `--track_object_appearing_later_in_video` flag " + "for VOS datasets that don't have all objects to track appearing " + "in the first frame (such as LVOS or YouTube-VOS)." + ) + # run propagation throughout the video and collect the results in a dict + os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True) + output_palette = input_palette or DAVIS_PALETTE + video_segments = {} # video_segments contains the per-frame segmentation results + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( + inference_state + ): + per_obj_output_mask = { + out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + video_segments[out_frame_idx] = per_obj_output_mask + + # write the output masks as palette PNG files to output_mask_dir + for out_frame_idx, per_obj_output_mask in video_segments.items(): + save_masks_to_dir( + output_mask_dir=output_mask_dir, + video_name=video_name, + frame_name=frame_names[out_frame_idx], + per_obj_output_mask=per_obj_output_mask, + height=height, + width=width, + per_obj_png_file=per_obj_png_file, + output_palette=output_palette, + ) + + +@torch.inference_mode() +@torch.autocast(device_type="cuda", dtype=torch.bfloat16) +def vos_separate_inference_per_object( + predictor, + base_video_dir, + input_mask_dir, + output_mask_dir, + video_name, + score_thresh=0.0, + use_all_masks=False, + per_obj_png_file=False, +): + """ + Run VOS inference on a single video with the given predictor. + + Unlike `vos_inference`, this function run inference separately for each object + in a video, which could be applied to datasets like LVOS or YouTube-VOS that + don't have all objects to track appearing in the first frame (i.e. some objects + might appear only later in the video). + """ + # load the video frames and initialize the inference state on this video + video_dir = os.path.join(base_video_dir, video_name) + frame_names = [ + os.path.splitext(p)[0] + for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + inference_state = predictor.init_state( + video_path=video_dir, async_loading_frames=False + ) + height = inference_state["video_height"] + width = inference_state["video_width"] + input_palette = None + + # collect all the object ids and their input masks + inputs_per_object = defaultdict(dict) + for idx, name in enumerate(frame_names): + if per_obj_png_file or os.path.exists( + os.path.join(input_mask_dir, video_name, f"{name}.png") + ): + per_obj_input_mask, input_palette = load_masks_from_dir( + input_mask_dir=input_mask_dir, + video_name=video_name, + frame_name=frame_names[idx], + per_obj_png_file=per_obj_png_file, + allow_missing=True, + ) + for object_id, object_mask in per_obj_input_mask.items(): + # skip empty masks + if not np.any(object_mask): + continue + # if `use_all_masks=False`, we only use the first mask for each object + if len(inputs_per_object[object_id]) > 0 and not use_all_masks: + continue + print(f"adding mask from frame {idx} as input for {object_id=}") + inputs_per_object[object_id][idx] = object_mask + + # run inference separately for each object in the video + object_ids = sorted(inputs_per_object) + output_scores_per_object = defaultdict(dict) + for object_id in object_ids: + # add those input masks to SAM 2 inference state before propagation + input_frame_inds = sorted(inputs_per_object[object_id]) + predictor.reset_state(inference_state) + for input_frame_idx in input_frame_inds: + predictor.add_new_mask( + inference_state=inference_state, + frame_idx=input_frame_idx, + obj_id=object_id, + mask=inputs_per_object[object_id][input_frame_idx], + ) + + # run propagation throughout the video and collect the results in a dict + for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video( + inference_state, + start_frame_idx=min(input_frame_inds), + reverse=False, + ): + obj_scores = out_mask_logits.cpu().numpy() + output_scores_per_object[object_id][out_frame_idx] = obj_scores + + # post-processing: consolidate the per-object scores into per-frame masks + os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True) + output_palette = input_palette or DAVIS_PALETTE + video_segments = {} # video_segments contains the per-frame segmentation results + for frame_idx in range(len(frame_names)): + scores = torch.full( + size=(len(object_ids), 1, height, width), + fill_value=-1024.0, + dtype=torch.float32, + ) + for i, object_id in enumerate(object_ids): + if frame_idx in output_scores_per_object[object_id]: + scores[i] = torch.from_numpy( + output_scores_per_object[object_id][frame_idx] + ) + + if not per_obj_png_file: + scores = predictor._apply_non_overlapping_constraints(scores) + per_obj_output_mask = { + object_id: (scores[i] > score_thresh).cpu().numpy() + for i, object_id in enumerate(object_ids) + } + video_segments[frame_idx] = per_obj_output_mask + + # write the output masks as palette PNG files to output_mask_dir + for frame_idx, per_obj_output_mask in video_segments.items(): + save_masks_to_dir( + output_mask_dir=output_mask_dir, + video_name=video_name, + frame_name=frame_names[frame_idx], + per_obj_output_mask=per_obj_output_mask, + height=height, + width=width, + per_obj_png_file=per_obj_png_file, + output_palette=output_palette, + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--sam2_cfg", + type=str, + default="configs/sam2.1/sam2.1_hiera_b+.yaml", + help="SAM 2 model configuration file", + ) + parser.add_argument( + "--sam2_checkpoint", + type=str, + default="./checkpoints/sam2.1_hiera_base_plus.pt", + help="path to the SAM 2 model checkpoint", + ) + parser.add_argument( + "--base_video_dir", + type=str, + required=True, + help="directory containing videos (as JPEG files) to run VOS prediction on", + ) + parser.add_argument( + "--input_mask_dir", + type=str, + required=True, + help="directory containing input masks (as PNG files) of each video", + ) + parser.add_argument( + "--video_list_file", + type=str, + default=None, + help="text file containing the list of video names to run VOS prediction on", + ) + parser.add_argument( + "--output_mask_dir", + type=str, + required=True, + help="directory to save the output masks (as PNG files)", + ) + parser.add_argument( + "--score_thresh", + type=float, + default=0.0, + help="threshold for the output mask logits (default: 0.0)", + ) + parser.add_argument( + "--use_all_masks", + action="store_true", + help="whether to use all available PNG files in input_mask_dir " + "(default without this flag: just the first PNG file as input to the SAM 2 model; " + "usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)", + ) + parser.add_argument( + "--per_obj_png_file", + action="store_true", + help="whether use separate per-object PNG files for input and output masks " + "(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; " + "note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)", + ) + parser.add_argument( + "--apply_postprocessing", + action="store_true", + help="whether to apply postprocessing (e.g. hole-filling) to the output masks " + "(we don't apply such post-processing in the SAM 2 model evaluation)", + ) + parser.add_argument( + "--track_object_appearing_later_in_video", + action="store_true", + help="whether to track objects that appear later in the video (i.e. not on the first frame; " + "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)", + ) + parser.add_argument( + "--use_vos_optimized_video_predictor", + action="store_true", + help="whether to use vos optimized video predictor with all modules compiled", + ) + args = parser.parse_args() + + # if we use per-object PNG files, they could possibly overlap in inputs and outputs + hydra_overrides_extra = [ + "++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true") + ] + predictor = build_sam2_video_predictor( + config_file=args.sam2_cfg, + ckpt_path=args.sam2_checkpoint, + apply_postprocessing=args.apply_postprocessing, + hydra_overrides_extra=hydra_overrides_extra, + vos_optimized=args.use_vos_optimized_video_predictor, + ) + + if args.use_all_masks: + print("using all available masks in input_mask_dir as input to the SAM 2 model") + else: + print( + "using only the first frame's mask in input_mask_dir as input to the SAM 2 model" + ) + # if a video list file is provided, read the video names from the file + # (otherwise, we use all subdirectories in base_video_dir) + if args.video_list_file is not None: + with open(args.video_list_file, "r") as f: + video_names = [v.strip() for v in f.readlines()] + else: + video_names = [ + p + for p in os.listdir(args.base_video_dir) + if os.path.isdir(os.path.join(args.base_video_dir, p)) + ] + print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}") + + for n_video, video_name in enumerate(video_names): + print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}") + if not args.track_object_appearing_later_in_video: + vos_inference( + predictor=predictor, + base_video_dir=args.base_video_dir, + input_mask_dir=args.input_mask_dir, + output_mask_dir=args.output_mask_dir, + video_name=video_name, + score_thresh=args.score_thresh, + use_all_masks=args.use_all_masks, + per_obj_png_file=args.per_obj_png_file, + ) + else: + vos_separate_inference_per_object( + predictor=predictor, + base_video_dir=args.base_video_dir, + input_mask_dir=args.input_mask_dir, + output_mask_dir=args.output_mask_dir, + video_name=video_name, + score_thresh=args.score_thresh, + use_all_masks=args.use_all_masks, + per_obj_png_file=args.per_obj_png_file, + ) + + print( + f"completed VOS prediction on {len(video_names)} videos -- " + f"output masks saved to {args.output_mask_dir}" + ) + + +if __name__ == "__main__": + main() diff --git a/third_party/sam2/training/README.md b/third_party/sam2/training/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b0c829d49d051d8f72e7bef959e33e6f0329c94d --- /dev/null +++ b/third_party/sam2/training/README.md @@ -0,0 +1,116 @@ +# Training Code for SAM 2 + +This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos. +The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both). + +## Structure + +The training code is organized into the following subfolders: + +* `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms. +* `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling). +* `utils`: This folder contains training utils such as loggers and distributed training utils. +* `scripts`: This folder contains the script to extract the frames of SA-V dataset to be used in training. +* `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training. +* `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers. +* `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra` configurable modules (model, optimizer, datasets, etc..) and implements the main train/eval loop. +* `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h` + +## Getting Started + +To get started with the training code, we provide a simple example to fine-tune our checkpoints on [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets. + +#### Requirements: +- We assume training on A100 GPUs with **80 GB** of memory. +- Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download). + +#### Steps to fine-tune on MOSE: +- Install the packages required for training by running `pip install -e ".[dev]"`. +- Set the paths for MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`. + ```yaml + dataset: + # PATHS to Dataset + img_folder: null # PATH to MOSE JPEGImages folder + gt_folder: null # PATH to MOSE Annotations folder + file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training + ``` +- To fine-tune the base model on MOSE using 8 GPUs, run + + ```python + python training/train.py \ + -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \ + --use-cluster 0 \ + --num-gpus 8 + ``` + + We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html), for example, you can train on 2 nodes by running + + ```python + python training/train.py \ + -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \ + --use-cluster 1 \ + --num-gpus 8 \ + --num-nodes 2 + --partition $PARTITION \ + --qos $QOS \ + --account $ACCOUNT + ``` + where partition, qos, and account are optional and depend on your SLURM configuration. + By default, the checkpoint and logs will be saved under `sam2_logs` directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows: + + ```yaml + experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} + ``` + The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split]( ../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run the `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation). The expected MOSE J&F after fine-tuning the Base plus model is 79.4. + + + After training/fine-tuning, you can then use the new checkpoint (saved in `checkpoints/` in the experiment log directory) similar to SAM 2 released checkpoints (as illustrated [here](../README.md#image-prediction)). +## Training on images and videos +The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, as well as any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py). Below is an example of how to setup the datasets in your config to train on a mix of image and video datasets: + +```yaml +data: + train: + _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset + phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases + batch_sizes: # List of batch sizes corresponding to each dataset + - ${bs1} # Batch size of dataset 1 + - ${bs2} # Batch size of dataset 2 + datasets: + # SA1B as an example of an image dataset + - _target_: training.dataset.vos_dataset.VOSDataset + training: true + video_dataset: + _target_: training.dataset.vos_raw_dataset.SA1BRawDataset + img_folder: ${path_to_img_folder} + gt_folder: ${path_to_gt_folder} + file_list_txt: ${path_to_train_filelist} # Optional + sampler: + _target_: training.dataset.vos_sampler.RandomUniformSampler + num_frames: 1 + max_num_objects: ${max_num_objects_per_image} + transforms: ${image_transforms} + # SA-V as an example of a video dataset + - _target_: training.dataset.vos_dataset.VOSDataset + training: true + video_dataset: + _target_: training.dataset.vos_raw_dataset.JSONRawDataset + img_folder: ${path_to_img_folder} + gt_folder: ${path_to_gt_folder} + file_list_txt: ${path_to_train_filelist} # Optional + ann_every: 4 + sampler: + _target_: training.dataset.vos_sampler.RandomUniformSampler + num_frames: 8 # Number of frames per video + max_num_objects: ${max_num_objects_per_video} + reverse_time_prob: ${reverse_time_prob} # probability to reverse video + transforms: ${video_transforms} + shuffle: True + num_workers: ${num_train_workers} + pin_memory: True + drop_last: True + collate_fn: + _target_: training.utils.data_utils.collate_fn + _partial_: true + dict_key: all +``` diff --git a/third_party/sam2/training/__init__.py b/third_party/sam2/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/training/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/training/assets/MOSE_sample_train_list.txt b/third_party/sam2/training/assets/MOSE_sample_train_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..28b22e3170f63de0fba3c77ef999f958cd6c48ff --- /dev/null +++ b/third_party/sam2/training/assets/MOSE_sample_train_list.txt @@ -0,0 +1,1246 @@ +28191f94 +662487fe +80906bf9 +7e704f2e +efa25913 +b6f03bd9 +6834d249 +5a723c30 +07779415 +4ce088c6 +199995b5 +54273925 +4fa342f5 +110da3cf +65856fa0 +46705bb3 +d869a3cf +555aa049 +8f01fb2c +37b07a28 +5e80b3dd +ba0e4dd4 +6f5144b6 +acec8407 +93723f88 +c7c7528c +97f58761 +e71f9faa +e64c13dc +8830d59d +0e4aeed9 +63437cf3 +95215aa1 +255f86ef +dc54aab2 +327cd258 +198021ad +c690220c +d25ff89d +7875b874 +4fa6d325 +9fc933f6 +4d8baafe +55ae6921 +6a3bc149 +89f8163f +2d65d2ac +dba172b1 +a14de179 +4017d1b3 +52ddf44c +3ba93641 +34a5f964 +da7dee28 +872b76de +1dc12eca +265a69f4 +86a2b59f +51e5ca25 +ddf80bcd +6786602e +4fa28c89 +f56942e9 +2184bb93 +d883e976 +bfe1469e +bc4e7b11 +1c80acb0 +2b0e34d3 +56b9ce41 +15f0b0cd +cc5d0dd1 +1b7eada8 +7286b176 +0ab42ab1 +adb82dc9 +c060b1e6 +3da63bd5 +5488796e +d7066e20 +aab5ed11 +17f66311 +24df9789 +208fa934 +7ce2c865 +debe4249 +4c56bbea +149dbae2 +beb693c9 +49eb0315 +e7ad4717 +4e016d5a +95e24093 +07b5d86c +80701b6c +337dfa1e +b624a46e +3f849de8 +5db21df2 +47891b4c +a966d7fd +013103f6 +da5e4bc5 +ba9ea03d +526195de +57f3a53e +b3aff7f8 +26048547 +bb7ee856 +aef0d049 +e35a8262 +57ad022e +f45d3823 +e5e9eb29 +39cc637e +a4fc4f17 +dd5a4739 +bbe97d18 +33602f6b +9061dac9 +23454d80 +a20baeec +794f01d4 +02de2f2a +055fca57 +a69df343 +e307510e +d07ad1be +1fc5e086 +db6533a5 +fe9706b7 +87e32230 +8ba58e4c +561f6380 +2ab9ba0f +86571569 +756cc6c9 +aa185af5 +c6d7f94b +7f54c579 +71f4b40e +4190c83a +fef0aba4 +2f7c71bb +e4b6f2ef +76adaeea +11cdeb64 +733f2a02 +e50dbddb +f643141f +d2e75e95 +84559bc3 +7ade3068 +e69db797 +0b787263 +57895315 +d7969c29 +62529cd4 +203733e7 +48fd97a6 +723fd024 +849f0efb +aafea009 +dd4eb8f1 +d18554ae +f3c0f0cf +90fe55b9 +b0ffaf3b +e79ecd47 +d670ce7b +56a5643a +90ff1d09 +1fb378d9 +57014c7d +994ed763 +5bc7ea74 +e99bd793 +cbb66185 +5f3fcff6 +05ed1023 +85efa9e3 +652929ce +905d8740 +a6fcde01 +0fdf67f7 +a5cf4c8d +e1c48bdd +782551f7 +6acd353f +c30641cf +81d12756 +51befc31 +9d5ab5ca +d262b7e4 +2cd705a9 +f7360199 +d3f3bf9d +028f6f64 +94767cb4 +3a739934 +72433603 +ec66879d +6149becc +5845c157 +c5082b3c +f89b54d0 +f3ada126 +409dcb8a +4411fdee +eb93ed20 +9cb1ba0e +b8e1ec26 +7edd8b4f +5e9412c0 +2744f35a +dafeb75e +f3f072f2 +6f1df574 +5a064706 +89c76ac4 +a6adef89 +76303516 +dbd67417 +a53ef3fa +10552818 +ac7deb19 +2d403c59 +55c157f1 +214aeac3 +a9f5e251 +d7807996 +d1dba33b +1367e367 +44476e77 +0644075b +eda37457 +f2de4198 +9a4ce701 +46e00caf +2ae75f99 +cd49fb99 +4e4483e7 +a0669957 +a6f0d882 +9ce1d54a +1fc2314b +21f363b3 +32ecef67 +70bcaf68 +115348f9 +60827ada +a218e951 +6d30d5ac +6da17988 +f22c39ce +5825f0e0 +f415f9ad +0d4feda2 +832fc243 +414ca58b +a92390a0 +ddd383cc +43dc67f7 +962ae0e2 +6dd74e7b +2bcd6c3b +b394847f +637fd121 +d46e771b +f6bfc699 +63f138de +932ad0a6 +2080824a +52fa9174 +843d3bf7 +f3431885 +5c20c48a +134a2ab0 +2ea465de +f6786ab5 +2bf49664 +a49ce97b +6a50e93a +a7c21e95 +616ad8ec +0a8d7b41 +b0c90527 +2d893fb7 +19310598 +7744dc51 +4539b907 +9d299f60 +e495537a +0b02886a +f4c4a2ca +e957b2b5 +e6f3bf07 +258944c8 +54364322 +ebb77f95 +0af03282 +cbdbc6c3 +494ecef0 +ee91f783 +9698f06e +11e16068 +b942ce0a +423a50e6 +fb16e746 +9c88ae45 +8620c024 +d3af3c85 +780a25de +e569a15f +c4f9f19e +1106f3a7 +d37e29a7 +e53611da +fdb2e432 +18ad3117 +6fcd426d +3bfa8379 +3b19c5c3 +ff1142df +cd182615 +b60ea255 +b3f5d019 +6dc5e55d +103166c7 +37af9ac1 +ad1881d1 +731149b3 +90e3338a +6aa0b6f2 +a25316a3 +dc8679e0 +571fb490 +80afed16 +983a551b +a58578e5 +2bc0bba4 +1143b3fe +fdd8dd49 +7fe2bf77 +890ef032 +8466eeb2 +c791ddbb +631b82bd +78bf9b51 +a99df45f +2bdb692f +e89b1501 +4e6aa1e8 +e5665030 +fe21fd5c +635577d5 +4414cd3a +03c99e83 +ff041cd1 +c33adbc2 +a988ec74 +576031e0 +03c21af7 +79b25f4b +bbc485d6 +d36d5a0d +efdab888 +b20e6781 +81fdc526 +e1c26a53 +7c6d3504 +52a04667 +f22e34d4 +bb936ead +13f0606c +d2abc61e +af509e8f +bea1c144 +e15e4de8 +e727099f +b30744df +ffb6a2e4 +0d31d3a6 +a23048fe +7d452630 +6c736334 +046ed4f4 +94f4c2aa +c290cfd3 +f7203226 +2fdae3c5 +7c78e351 +02b72b8d +2d22d3be +ba28d02e +197f6587 +43199a98 +b563b04f +9293b755 +9cef7489 +d156b96f +15e9161e +6d094cd5 +0d876a65 +c818d30a +8094b12b +a4a8e24b +14655f54 +11c14893 +8a48f62a +7f3d9c22 +d952481c +03e0f9b8 +28980657 +6a0b5563 +5879983c +37549a79 +4a7162bd +7a6aa1ef +0dc1b78c +f6dba17b +1dba51af +b2f4d608 +e2e6f421 +464066da +5d24e4ea +1e75004d +a02ed92c +673adbcc +c2a0c0fd +85addee5 +54b8f502 +f5d2d8d3 +a19507e1 +803e1756 +0d1fe009 +5968c2d8 +b926e1ad +a9162e14 +ae470d2b +bd731802 +68c879f2 +21fe05d9 +c1ed21d0 +831498e4 +cc45a7f2 +cb170015 +59750be4 +30d1cb6b +03e5f069 +106d33db +3f003746 +3e5ad020 +8bc5a91c +64b89eb5 +bfd28682 +f8687b9a +7bbf38ee +d6d92b30 +ceaa6c65 +677c8ed7 +dc33acf8 +cfd1de31 +e5be4781 +85585220 +5d2316f6 +dd3f4a07 +34535f5f +3ae0bc5d +f521e3c5 +74c2284f +12a42fd9 +61403519 +88cd32f3 +662a1846 +825a1944 +cf376cf1 +8465d99c +61a2e246 +62d44645 +103b3ca8 +c7e745ed +4ed71139 +230c2edf +529c6889 +9e509c0d +54b9dea2 +a8934c0d +29cffe2f +48017512 +c9f7f69d +ce691ee6 +21c89360 +3b97c07b +ebd82d35 +2895bb8b +7043c5c1 +85d694d7 +88fd7507 +18d8931e +aa718745 +89b671bb +0d8d30ae +26163977 +a6121689 +1589579d +159789c4 +f5ca8271 +fcc16740 +3158be0b +860fc1f7 +3f54a330 +82f24ce7 +069f6a2a +2fa9c523 +c9f1d87f +efe9cbca +8f969ea5 +4f5db794 +62c501f8 +2d3b0320 +c99637f0 +0f3b1fcb +6e4ee861 +e0d9aff0 +230ddb91 +e14d1f96 +c83aa6a1 +eabdf66a +6783a303 +81659eb2 +ce954bd7 +9a48c0c9 +0ab807b4 +f0617f71 +fe86f2f8 +61d80e22 +e4b6d2a0 +ac093040 +0e05fabe +d0b507c3 +3d828137 +c4fa0bab +f7783321 +ec27366a +404e4c58 +073baf48 +0f685e01 +b0e98fdd +b4891f7f +a46b7b77 +ee059f99 +3c87888e +8d23ddcc +2d8d7d35 +5680be79 +fc79c03e +20660b72 +53f67585 +90956534 +7e709e2d +dae93f5c +54b9dbba +cc41ba05 +1e207fe0 +a9c6abf2 +35e0ca09 +e3dcd186 +1b8bb699 +92162474 +cdad6812 +50b91533 +570215ac +6042d64a +b6e2c041 +08746283 +7a056996 +b8651773 +adf443e1 +6a6e0e3b +886ed981 +c1d57fea +43030c4c +7ebfbf57 +0770ad03 +e85301d5 +31ac3d98 +acaef45e +8f415dd1 +fe2dc281 +2c0b9d99 +8e24501e +911ec4ad +8036b58e +c3b350b9 +b6cadd11 +a3a80cf7 +88ab50cd +59c755a8 +1339321a +91b2f707 +97b0811e +1da33959 +31b09833 +c1a40349 +708098a9 +1f220f98 +999e07cb +0b5e5d29 +94c63453 +b826d642 +a598602d +4c83eab8 +2efd5e50 +6ec5da3a +9fcd95eb +9a2c6b5b +c205a718 +e638e950 +cb43141c +494dd91d +c4957274 +4975a81d +a1f4c54d +51e6fafa +514490e5 +b0d09e6a +c6726eb8 +06772c9a +5a65ffd7 +3657c62b +03012cfd +529df209 +f1c38e66 +ab417352 +118a067e +8957514f +22e8b380 +3b1a4616 +a4457543 +57c9f6e0 +e362c16b +0f809e41 +857e375e +9cff25e3 +d754fb65 +6ad44b86 +051052d8 +a4564b94 +f68507d0 +80a7cf7b +ad8cd1e0 +60b19cd3 +274fe944 +f06632aa +628a337b +92c96c05 +87fc565c +6f6e6c37 +228a0234 +6487110a +aa911a8e +40c47fa3 +9606508b +6ba9e61f +c8c1d5a9 +cf01df5b +9421b9ad +006e6b64 +1c28e081 +06273084 +8925e11b +b46c822b +00501424 +cfd946b2 +2e92a7dc +1c5f5bb6 +1d29944c +8248698e +19247506 +1eac1aff +ee9caa47 +4a41cbf8 +d97c9309 +4ca87c14 +9707f1e3 +8bb9a221 +6605e67d +95cf72d7 +1c6fb814 +033130b2 +4344808d +5f14e5d2 +a810399b +e325a6d4 +7014ddf4 +725d4bfb +790285e8 +1a6a731f +fbfb6e30 +0d4d88f6 +80ce18a4 +572495b7 +4b44dc50 +95dce33c +4a6fb202 +3142014e +a3c56751 +96b2a414 +c4aa176c +fd1e394f +93f0f509 +f494e9fa +bfa42a75 +db5319c7 +aa92e070 +81220a93 +e4a72496 +fc467bf1 +5397b01d +1dc0c9a0 +f6f8b4a6 +53dc7db4 +8ef303eb +62ca45c9 +e9d3465e +3784e3f6 +8c934e67 +5ba84e3f +30e41f1e +61cf0ec8 +e93e8f01 +fc6086dd +a95f0aea +33a04ef2 +6f295adb +d2aa8c66 +724cc810 +d8623d26 +8d0d641a +4bda7a76 +38030c69 +56199c41 +d2f4b9e2 +a7b8ac96 +64044df1 +fd1078cc +0165667b +16e1cca7 +915f0d9a +eeaaa67e +378430d5 +a84c60e6 +b4ae36cc +2a3a0571 +13e6df75 +aa348c45 +59d7a11d +68954daf +d6f883c6 +f28b429a +32dc49d4 +ccf14ee0 +7d512591 +9bdabdb2 +ed878d94 +54eda06d +132561ee +3c4b6736 +0367af42 +531c1c36 +843d8f25 +333bdbdc +c3c21268 +07b00746 +c7fe0584 +49fc9f2e +9ed4317a +d29991b4 +98b0033d +f0b922bf +89fe6899 +58264713 +2f49220a +6ff85ca5 +4b96b2c8 +a42f54f5 +aa425600 +22fdee40 +dde85a9d +3722f6fe +e7529cbc +5ae23f9f +cc32235b +730bc486 +b12701b7 +a96b3010 +16130bd3 +2c713560 +f7935d24 +a7eb6616 +0d6e7177 +100edaef +0442a954 +60f4fa43 +37bf7edf +76b18413 +ab0646a9 +c575434d +1e356390 +5416fbb7 +df7cf932 +269872de +9033b607 +c2e88575 +932542cd +23e046fb +3d08dadd +7999adc5 +ed81c485 +3bd7facd +1feae28e +8d72533b +6a8d35d6 +65308bdc +7f0b7662 +98290486 +fee3371f +c463c7e5 +faf7d852 +75c34dc5 +96a6722e +e5605136 +851bc5d9 +15c41c4b +6a39e104 +5fbff256 +0e7001dd +5411113f +3ea2f7f2 +242b74b1 +87727003 +ec6dd0e9 +980baf58 +9d0b7bf1 +9113c9d4 +5ebef6bd +a5f70ce7 +b0240233 +06ad78e0 +8745edd0 +d8e8d984 +ac32a655 +38568758 +d48c552d +0b27d5f7 +c65d0736 +800e3c14 +d37a5857 +bcebc660 +d3ab52cc +405e3ee7 +e33cddc9 +b0197182 +89fd5681 +9e192417 +8554c402 +aae923b8 +31af515d +75b26f88 +60471744 +460945aa +c0fe8e1a +1731babb +2e85e35d +f9c20062 +115da184 +ddfa88c7 +359003f8 +dfa99126 +bf04814f +f407a414 +e18723c4 +0a7a3629 +c07ab37e +1251a1c9 +4d09d22a +5984ed74 +34504f63 +ced51047 +08ff419c +d942e98c +2697f864 +3b671a61 +72a2f7e2 +48e7cafe +6adad2f7 +18840617 +1e44f47e +36cc4055 +8c494902 +2982de7a +6a428397 +c4a0ecfb +231d6945 +fe470104 +f93e1bd0 +bd18bc5a +7bd70d93 +8f81a0ee +db78e7a1 +7593caea +86d5b29b +5457b298 +0d967fd1 +62372d4c +68259db3 +f0944ea2 +7b017dbf +bcb6e338 +03692b14 +f7d36a47 +1ca2531a +6728528d +1fc0e6a8 +0ba9c5ad +a386eaa2 +b0c5459f +1d64aff3 +b97d4f1a +b3745d91 +c461003e +910bf878 +ae42601c +8d2ddeff +aaecaa39 +250b5034 +edb11192 +7bfe9b57 +6d533759 +51586b36 +a38d648a +8fdb48e5 +6075d6b0 +3588ea03 +bc844942 +398d41f5 +660e3b70 +0b99f522 +f169fd1b +7bfa2ab5 +ab461319 +25153e58 +002b4dce +a2df1bee +550a7357 +b604f2dd +2f477d05 +bdf9eb5a +857ddc6e +c8f0fd41 +6df96f15 +e147ab26 +788da8e8 +02221fb0 +d1d95c61 +a3f0cb28 +3a6e6ace +67c2909a +220382ab +eaed776d +aff08a61 +b99d1bd6 +9d9ae988 +34ccea00 +41dae436 +18513251 +ad57acd1 +67f110fc +3f09f5c9 +25ef7d43 +12a5d0d7 +3ff48b8b +26ed56e6 +c047a092 +bb8639e1 +8788747f +584838d4 +f8e5f837 +657242e8 +cb8eedf4 +74a917f1 +578f71da +c9b27125 +22e1f53c +f40145c2 +4795259b +3f313a2f +c9012bf6 +22167a50 +6e7f9437 +ef51a724 +356e0fcb +d3ea999d +08a5c662 +85aa3b0e +579fadec +7bc95dc2 +c097af8e +f01d8b9f +80fb79c6 +ea65e6b7 +29ff29f6 +9e1f739d +b7fb59c9 +e2160f17 +0be33bc1 +e96b9b04 +b1affe79 +c4f4b2e2 +f4c8ffb1 +6a009e50 +a8828854 +2786f841 +a64e724c +5f54d077 +7040385d +6e0f0ecc +f33d3c15 +8108b358 +46a502de +1e0fb02a +ddbdfa32 +e7b34ab6 +c9080ed1 +395224b3 +33f9ab47 +c245ecda +c28d81a9 +37303a3b +6380dd6f +2fb5a55b +83b7c53c +41c8d0d2 +3aab2d13 +dc7d21fb +86a88668 +37bb38fe +ab6413a8 +bbe585b2 +a0ca072a +9d5940d2 +ddb1d0b1 +a946317a +988b29a4 +89dc0432 +5df8490d +5e167efa +50a86faa +fe6a535a +a9f8b8b4 +6e2dce1b +d0696759 +c09da3b2 +f07dd347 +67408899 +406165ff +a4a9d03d +9b5f0f47 +5f3e8022 +1d7a23e0 +25af2eeb +82a3db34 +c9351029 +6c93d44c +f088ad1c +9ee59f51 +b5276b3f +ca74a924 +781af187 +fa3e0b85 +b898c99e +1ca51f06 +5a92a0c1 +138c81fe +d0722d0f +05a7d84d +e18f1dea +799a2d61 +8276e558 +f0ba8748 +ce733e8a +2f9d0911 +58f24fa4 +66a25278 +3135d31d +4b9223ee +bdd5e6b3 +ddbebec1 +8dbebbd9 +3020b38f +e607450d +724a5d1c +91b754c5 +2e85e790 +3a407bd9 +fd137178 +a304029b +4023fc77 +440d5072 +2eb73c7c +164a7305 +b33ade7c +277ad883 +b0f7e75c +74107936 +83924bdb +b72beb78 +86c01d64 +f6f441eb +23b9a3ea +80b73f1a +93c6411d +1e95ef5e +800b5eac +9519832a +ae043406 +b06a902e +1dbca5cc +571f88a1 +b1faf52b +45572497 +8d016cdb +f92cdae8 +316931f8 +f9884439 +e1b7f212 +e23c6392 +ccfae073 +5aa1efda +74f0687c +eaff3301 +b6520a94 +c5398714 +15e7e4d1 +0fc00006 +8cf49218 +3a8ddc0a +e7e2a0b9 +eec4c008 +8d73085e +77e246da +00e92ab4 +f76f6cf9 +19801183 +233406ef +b80e028c +342c0b2a +a2768c47 +99350a74 +adbd400b +f3978ade +b87a4f6c +fa95a6a2 +6dff20c9 +935b5ad8 +dbbbb401 +1b6472c1 +9c0e6331 +04ae7a6b +4c94e4f3 +90cb46cb +2831ecf5 +ff77a145 +79af6097 +ba61a719 +abcb7665 +7e87750e +c4c7bc5d +3a670b81 +3d9a7023 +82667d52 +a4587f62 +ca619b7f +7c5462f5 +bda5c60d +e6e48ac8 +405c6000 +7981f344 +f7375ab3 +bb467ff9 +cfc68a82 +e417a6d8 +1a6177c1 +7b75dace +b1af350d +484d48a3 +1f805416 +7416ab4e +1291276c +9e85179b +5a74660c +7e6d00df +01e3cec8 +ee2c0688 +f6de8226 +a217538c +b432c3ef +49e5ff4e +035359e5 +8ae8e7ed +2da12766 +cac39070 +115adda4 +1a2872dc +fac3378e +294e7bf8 +a1a4991f +c062f4d7 +72b2b77d +158062aa +9ae447a7 +a7b05677 +fdfd5d56 +eac1a9e6 +a5905593 +59992293 +84298fae +f708e55f +093d3d93 +75d26197 +924f5d88 +3184a7ec +b454fdbc +2d9101b8 +ae70fb7c +4385b2c4 +63b37343 +0b4b662c +2883ae72 +ffcab778 +0f96e2d7 +897066e3 +f23e98ad +797a7b7e +2fc476f9 diff --git a/third_party/sam2/training/assets/MOSE_sample_val_list.txt b/third_party/sam2/training/assets/MOSE_sample_val_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..9721028718245ff5297fdae59d35a7c89cb5f56a --- /dev/null +++ b/third_party/sam2/training/assets/MOSE_sample_val_list.txt @@ -0,0 +1,200 @@ +32e5d721 +5bad0bab +267bfd6c +0a43a414 +56c56ca9 +9a1146b3 +c6ad7aaf +78a1f4b1 +fc455e73 +072e7b3f +77ccb57d +a76ee415 +8cdcfc17 +5d518b42 +376dd830 +0e843fc8 +2af0e766 +2bd4e845 +de2f2a6a +ade9ee91 +001ca3cb +fc4c1c67 +8ef55579 +b84ce852 +4cc8528a +767ffaaa +112a2ef0 +a338c8aa +cbd144f5 +5ff72128 +86a949e2 +9f2323ac +1fab1d1c +75924351 +ef55817b +02deca50 +4d979d99 +4d65f873 +28470fa0 +0d1575fe +06ea172e +29a6ddc2 +797f1bec +780e7a99 +b9ed5b44 +02a236b4 +607d8ff5 +af5666b2 +0558d0ed +a938c6b2 +103df575 +77110e80 +739e5a07 +6763a576 +06ebc138 +ba4b3b09 +b35cc2f3 +4e0597a0 +5949ee84 +5348d547 +323c4236 +b3b51117 +55727ddd +ab2714f3 +d2878895 +c0734cb3 +94f7c53e +2a2745e5 +442ffb54 +3592425a +50ae03b0 +5f150435 +3067f9fa +9ffb2818 +adeaf5aa +31caacec +1cd99b86 +aa22f9d0 +8fa50320 +e6348d2c +42ff84a5 +8c8b7913 +c96adcbc +495be321 +db735509 +ee113fc4 +a678cdab +c409ca4d +68d2b259 +592b4dee +4e2b4dc7 +eb4d26e1 +2009a00f +bec5c89d +67191f24 +a3e85b4b +da7080cd +80d978e9 +36dcb93f +a41e8c44 +12fdc864 +46d140ea +657c9dd9 +a86f84ee +90c1c43d +33015509 +afc7664d +23df06e1 +291d4799 +0ab75563 +251bf059 +bcefdcc4 +ce9a2796 +94d3403a +8f2e04bc +f9cda066 +9dfa2cc5 +66924c91 +e765a09e +15654ee1 +48e0bd39 +ee095221 +2463609b +544d0d1f +51b8c2e1 +d321dde4 +4cb11a5f +d7058a0d +37af282a +fabae187 +7be91184 +181ec185 +2d16ceeb +b56be4b1 +6699eff0 +79acac96 +d61c4665 +0c13e1e7 +100f6ecf +71217dfc +82df0888 +4c42c747 +c9fdf703 +d2efeb4b +69ed9d14 +64914fb6 +255bedbc +4ea934d8 +a034feb2 +e4f4ddae +e36a3026 +c1489591 +111bb373 +e1d9fb32 +93e22d48 +c1ec4b26 +d9638e69 +60ab04c5 +cfe7773a +62132822 +2f5fb2a3 +7bdd197d +033333fd +130fcdbe +12e509c2 +67138c33 +6f90cc5f +4e3020fe +bbdd8bb7 +b399ccdb +fecd10d2 +2e0967f7 +f509054f +792c6ff7 +48e2afc5 +d904c048 +111e0a5c +b83024e2 +e6a7b79c +bdc5ccf7 +b8146d00 +9d394f1a +645b84f9 +95ab2d0f +e6f8a31d +b4f876fb +dc2c570d +3afd02d7 +5c80c82c +b1b32ddd +9f25fc61 +ba538072 +f8916fef +43c04ad2 +a658e949 +2861dd53 +f6e40aba +09d305d1 +aac33bff +8d9d4c08 diff --git a/third_party/sam2/training/dataset/__init__.py b/third_party/sam2/training/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/training/dataset/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/training/dataset/sam2_datasets.py b/third_party/sam2/training/dataset/sam2_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..6deda056bea555fc07ace455ccc62c606a7b81c9 --- /dev/null +++ b/third_party/sam2/training/dataset/sam2_datasets.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import Callable, Iterable, List, Optional, Sequence + +import torch + +from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset + +from torch.utils.data.distributed import DistributedSampler + + +class MixedDataLoader: + def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor): + """ + Args: + dataloaders (List[DataLoader]): List of DataLoaders to be mixed. + mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from + + """ + assert len(dataloaders) == mixing_prob.shape[0] + self.dataloaders = dataloaders + self.mixing_prob = mixing_prob + # Iterator state + self._iter_dls = None + self._iter_mixing_prob = None + self.random_generator = torch.Generator() + + def __len__(self): + return sum([len(d) for d in self.dataloaders]) + + def __iter__(self): + # Synchronize dataloader seeds + self.random_generator.manual_seed(42) + self._iter_dls = [iter(loader) for loader in self.dataloaders] + self._iter_mixing_prob = self.mixing_prob.clone() + return self + + def __next__(self): + """ + Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted. + """ + if self._iter_dls is None: + raise TypeError(f"{type(self).__name__} object is not an iterator") + + while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob. + dataset_idx = self._iter_mixing_prob.multinomial( + 1, generator=self.random_generator + ).item() + try: + item = next(self._iter_dls[dataset_idx]) + return item + except StopIteration: + # No more iterations for this dataset, set it's mixing probability to zero and try again. + self._iter_mixing_prob[dataset_idx] = 0 + except Exception as e: + # log and raise any other unexpected error. + logging.error(e) + raise e + + # Exhausted all iterators + raise StopIteration + + +class TorchTrainMixedDataset: + def __init__( + self, + datasets: List[Dataset], + batch_sizes: List[int], + num_workers: int, + shuffle: bool, + pin_memory: bool, + drop_last: bool, + collate_fn: Optional[Callable] = None, + worker_init_fn: Optional[Callable] = None, + phases_per_epoch: int = 1, + dataset_prob: Optional[List[float]] = None, + ) -> None: + """ + Args: + datasets (List[Dataset]): List of Datasets to be mixed. + batch_sizes (List[int]): Batch sizes for each dataset in the list. + num_workers (int): Number of workers per dataloader. + shuffle (bool): Whether or not to shuffle data. + pin_memory (bool): If True, use pinned memory when loading tensors from disk. + drop_last (bool): Whether or not to drop the last batch of data. + collate_fn (Callable): Function to merge a list of samples into a mini-batch. + worker_init_fn (Callable): Function to init each dataloader worker. + phases_per_epoch (int): Number of phases per epoch. + dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0 + """ + + self.datasets = datasets + self.batch_sizes = batch_sizes + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + self.collate_fn = collate_fn + self.worker_init_fn = worker_init_fn + assert len(self.datasets) > 0 + for dataset in self.datasets: + assert not isinstance(dataset, IterableDataset), "Not supported" + # `RepeatFactorWrapper` requires calling set_epoch first to get its length + self._set_dataset_epoch(dataset, 0) + self.phases_per_epoch = phases_per_epoch + self.chunks = [None] * len(datasets) + if dataset_prob is None: + # If not provided, assign each dataset a probability proportional to its length. + dataset_lens = [ + (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs)) + for d, bs in zip(datasets, batch_sizes) + ] + total_len = sum(dataset_lens) + dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens]) + else: + assert len(dataset_prob) == len(datasets) + dataset_prob = torch.tensor(dataset_prob) + + logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}") + assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0" + self.dataset_prob = dataset_prob + + def _set_dataset_epoch(self, dataset, epoch: int) -> None: + if hasattr(dataset, "epoch"): + dataset.epoch = epoch + if hasattr(dataset, "set_epoch"): + dataset.set_epoch(epoch) + + def get_loader(self, epoch) -> Iterable: + dataloaders = [] + for d_idx, (dataset, batch_size) in enumerate( + zip(self.datasets, self.batch_sizes) + ): + if self.phases_per_epoch > 1: + # Major epoch that looops over entire dataset + # len(main_epoch) == phases_per_epoch * len(epoch) + main_epoch = epoch // self.phases_per_epoch + + # Phase with in the main epoch + local_phase = epoch % self.phases_per_epoch + + # Start of new data-epoch or job is resumed after preemtion. + if local_phase == 0 or self.chunks[d_idx] is None: + # set seed for dataset epoch + # If using RepeatFactorWrapper, this step currectly re-samples indices before chunking. + self._set_dataset_epoch(dataset, main_epoch) + + # Separate random generator for subset sampling + g = torch.Generator() + g.manual_seed(main_epoch) + self.chunks[d_idx] = torch.chunk( + torch.randperm(len(dataset), generator=g), + self.phases_per_epoch, + ) + + dataset = Subset(dataset, self.chunks[d_idx][local_phase]) + else: + self._set_dataset_epoch(dataset, epoch) + + sampler = DistributedSampler(dataset, shuffle=self.shuffle) + sampler.set_epoch(epoch) + + batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last) + dataloaders.append( + DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + worker_init_fn=self.worker_init_fn, + ) + ) + return MixedDataLoader(dataloaders, self.dataset_prob) diff --git a/third_party/sam2/training/dataset/transforms.py b/third_party/sam2/training/dataset/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5c6512ac7fd9548273fb152a3b57ef75e4fc18 --- /dev/null +++ b/third_party/sam2/training/dataset/transforms.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transforms and data augmentation for both image + bbox. +""" + +import logging + +import random +from typing import Iterable + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as Fv2 +from PIL import Image as PILImage + +from torchvision.transforms import InterpolationMode + +from training.utils.data_utils import VideoDatapoint + + +def hflip(datapoint, index): + + datapoint.frames[index].data = F.hflip(datapoint.frames[index].data) + for obj in datapoint.frames[index].objects: + if obj.segment is not None: + obj.segment = F.hflip(obj.segment) + + return datapoint + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = max_size * min_original_size / max_original_size + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = int(round(size)) + oh = int(round(size * h / w)) + else: + oh = int(round(size)) + ow = int(round(size * w / h)) + + return (oh, ow) + + +def resize(datapoint, index, size, max_size=None, square=False, v2=False): + # size can be min_size (scalar) or (w, h) tuple + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + cur_size = ( + datapoint.frames[index].data.size()[-2:][::-1] + if v2 + else datapoint.frames[index].data.size + ) + size = get_size(cur_size, size, max_size) + + old_size = ( + datapoint.frames[index].data.size()[-2:][::-1] + if v2 + else datapoint.frames[index].data.size + ) + if v2: + datapoint.frames[index].data = Fv2.resize( + datapoint.frames[index].data, size, antialias=True + ) + else: + datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size) + + new_size = ( + datapoint.frames[index].data.size()[-2:][::-1] + if v2 + else datapoint.frames[index].data.size + ) + + for obj in datapoint.frames[index].objects: + if obj.segment is not None: + obj.segment = F.resize(obj.segment[None, None], size).squeeze() + + h, w = size + datapoint.frames[index].size = (h, w) + return datapoint + + +def pad(datapoint, index, padding, v2=False): + old_h, old_w = datapoint.frames[index].size + h, w = old_h, old_w + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + datapoint.frames[index].data = F.pad( + datapoint.frames[index].data, (0, 0, padding[0], padding[1]) + ) + h += padding[1] + w += padding[0] + else: + # left, top, right, bottom + datapoint.frames[index].data = F.pad( + datapoint.frames[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + h += padding[1] + padding[3] + w += padding[0] + padding[2] + + datapoint.frames[index].size = (h, w) + + for obj in datapoint.frames[index].objects: + if obj.segment is not None: + if v2: + if len(padding) == 2: + obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = Fv2.pad(obj.segment, tuple(padding)) + else: + if len(padding) == 2: + obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = F.pad(obj.segment, tuple(padding)) + return datapoint + + +class RandomHorizontalFlip: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for i in range(len(datapoint.frames)): + datapoint = hflip(datapoint, i) + return datapoint + for i in range(len(datapoint.frames)): + if random.random() < self.p: + datapoint = hflip(datapoint, i) + return datapoint + + +class RandomResizeAPI: + def __init__( + self, sizes, consistent_transform, max_size=None, square=False, v2=False + ): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + self.consistent_transform = consistent_transform + self.v2 = v2 + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + size = random.choice(self.sizes) + for i in range(len(datapoint.frames)): + datapoint = resize( + datapoint, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return datapoint + for i in range(len(datapoint.frames)): + size = random.choice(self.sizes) + datapoint = resize( + datapoint, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return datapoint + + +class ToTensorAPI: + def __init__(self, v2=False): + self.v2 = v2 + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + for img in datapoint.frames: + if self.v2: + img.data = Fv2.to_image_tensor(img.data) + else: + img.data = F.to_tensor(img.data) + return datapoint + + +class NormalizeAPI: + def __init__(self, mean, std, v2=False): + self.mean = mean + self.std = std + self.v2 = v2 + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + for img in datapoint.frames: + if self.v2: + img.data = Fv2.convert_image_dtype(img.data, torch.float32) + img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) + else: + img.data = F.normalize(img.data, mean=self.mean, std=self.std) + + return datapoint + + +class ComposeAPI: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, datapoint, **kwargs): + for t in self.transforms: + datapoint = t(datapoint, **kwargs) + return datapoint + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class RandomGrayscale: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + self.Grayscale = T.Grayscale(num_output_channels=3) + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for img in datapoint.frames: + img.data = self.Grayscale(img.data) + return datapoint + for img in datapoint.frames: + if random.random() < self.p: + img.data = self.Grayscale(img.data) + return datapoint + + +class ColorJitter: + def __init__(self, consistent_transform, brightness, contrast, saturation, hue): + self.consistent_transform = consistent_transform + self.brightness = ( + brightness + if isinstance(brightness, list) + else [max(0, 1 - brightness), 1 + brightness] + ) + self.contrast = ( + contrast + if isinstance(contrast, list) + else [max(0, 1 - contrast), 1 + contrast] + ) + self.saturation = ( + saturation + if isinstance(saturation, list) + else [max(0, 1 - saturation), 1 + saturation] + ) + self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + if self.consistent_transform: + # Create a color jitter transformation params + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for img in datapoint.frames: + if not self.consistent_transform: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img.data = F.adjust_brightness(img.data, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img.data = F.adjust_contrast(img.data, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img.data = F.adjust_saturation(img.data, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img.data = F.adjust_hue(img.data, hue_factor) + return datapoint + + +class RandomAffine: + def __init__( + self, + degrees, + consistent_transform, + scale=None, + translate=None, + shear=None, + image_mean=(123, 116, 103), + log_warning=True, + num_tentatives=1, + image_interpolation="bicubic", + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random affine is applied to all frames and masks. + """ + self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) + self.scale = scale + self.shear = ( + shear if isinstance(shear, list) else ([-shear, shear] if shear else None) + ) + self.translate = translate + self.fill_img = image_mean + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + + if image_interpolation == "bicubic": + self.image_interpolation = InterpolationMode.BICUBIC + elif image_interpolation == "bilinear": + self.image_interpolation = InterpolationMode.BILINEAR + else: + raise NotImplementedError + + def __call__(self, datapoint: VideoDatapoint, **kwargs): + for _tentative in range(self.num_tentatives): + res = self.transform_datapoint(datapoint) + if res is not None: + return res + + if self.log_warning: + logging.warning( + f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" + ) + return datapoint + + def transform_datapoint(self, datapoint: VideoDatapoint): + _, height, width = F.get_dimensions(datapoint.frames[0].data) + img_size = [width, height] + + if self.consistent_transform: + # Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + for img_idx, img in enumerate(datapoint.frames): + this_masks = [ + obj.segment.unsqueeze(0) if obj.segment is not None else None + for obj in img.objects + ] + if not self.consistent_transform: + # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + transformed_bboxes, transformed_masks = [], [] + for i in range(len(img.objects)): + if this_masks[i] is None: + transformed_masks.append(None) + # Dummy bbox for a dummy target + transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]])) + else: + transformed_mask = F.affine( + this_masks[i], + *affine_params, + interpolation=InterpolationMode.NEAREST, + fill=0.0, + ) + if img_idx == 0 and transformed_mask.max() == 0: + # We are dealing with a video and the object is not visible in the first frame + # Return the datapoint without transformation + return None + transformed_masks.append(transformed_mask.squeeze()) + + for i in range(len(img.objects)): + img.objects[i].segment = transformed_masks[i] + + img.data = F.affine( + img.data, + *affine_params, + interpolation=self.image_interpolation, + fill=self.fill_img, + ) + return datapoint + + +def random_mosaic_frame( + datapoint, + index, + grid_h, + grid_w, + target_grid_y, + target_grid_x, + should_hflip, +): + # Step 1: downsize the images and paste them into a mosaic + image_data = datapoint.frames[index].data + is_pil = isinstance(image_data, PILImage.Image) + if is_pil: + H_im = image_data.height + W_im = image_data.width + image_data_output = PILImage.new("RGB", (W_im, H_im)) + else: + H_im = image_data.size(-2) + W_im = image_data.size(-1) + image_data_output = torch.zeros_like(image_data) + + downsize_cache = {} + for grid_y in range(grid_h): + for grid_x in range(grid_w): + y_offset_b = grid_y * H_im // grid_h + x_offset_b = grid_x * W_im // grid_w + y_offset_e = (grid_y + 1) * H_im // grid_h + x_offset_e = (grid_x + 1) * W_im // grid_w + H_im_downsize = y_offset_e - y_offset_b + W_im_downsize = x_offset_e - x_offset_b + + if (H_im_downsize, W_im_downsize) in downsize_cache: + image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] + else: + image_data_downsize = F.resize( + image_data, + size=(H_im_downsize, W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + ) + downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize + if should_hflip[grid_y, grid_x].item(): + image_data_downsize = F.hflip(image_data_downsize) + + if is_pil: + image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) + else: + image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( + image_data_downsize + ) + + datapoint.frames[index].data = image_data_output + + # Step 2: downsize the masks and paste them into the target grid of the mosaic + for obj in datapoint.frames[index].objects: + if obj.segment is None: + continue + assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 + segment_output = torch.zeros_like(obj.segment) + + target_y_offset_b = target_grid_y * H_im // grid_h + target_x_offset_b = target_grid_x * W_im // grid_w + target_y_offset_e = (target_grid_y + 1) * H_im // grid_h + target_x_offset_e = (target_grid_x + 1) * W_im // grid_w + target_H_im_downsize = target_y_offset_e - target_y_offset_b + target_W_im_downsize = target_x_offset_e - target_x_offset_b + + segment_downsize = F.resize( + obj.segment[None, None], + size=(target_H_im_downsize, target_W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + )[0, 0] + if should_hflip[target_grid_y, target_grid_x].item(): + segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] + + segment_output[ + target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e + ] = segment_downsize + obj.segment = segment_output + + return datapoint + + +class RandomMosaicVideoAPI: + def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): + self.prob = prob + self.grid_h = grid_h + self.grid_w = grid_w + self.use_random_hflip = use_random_hflip + + def __call__(self, datapoint, **kwargs): + if random.random() > self.prob: + return datapoint + + # select a random location to place the target mask in the mosaic + target_grid_y = random.randint(0, self.grid_h - 1) + target_grid_x = random.randint(0, self.grid_w - 1) + # whether to flip each grid in the mosaic horizontally + if self.use_random_hflip: + should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 + else: + should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) + for i in range(len(datapoint.frames)): + datapoint = random_mosaic_frame( + datapoint, + i, + grid_h=self.grid_h, + grid_w=self.grid_w, + target_grid_y=target_grid_y, + target_grid_x=target_grid_x, + should_hflip=should_hflip, + ) + + return datapoint diff --git a/third_party/sam2/training/dataset/utils.py b/third_party/sam2/training/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a658df234c3dcf74404f844b5be793b0545485ed --- /dev/null +++ b/third_party/sam2/training/dataset/utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" + +from typing import Iterable + +import torch +from torch.utils.data import ( + ConcatDataset as TorchConcatDataset, + Dataset, + Subset as TorchSubset, +) + + +class ConcatDataset(TorchConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ConcatDataset, self).__init__(datasets) + + self.repeat_factors = torch.cat([d.repeat_factors for d in datasets]) + + def set_epoch(self, epoch: int): + for dataset in self.datasets: + if hasattr(dataset, "epoch"): + dataset.epoch = epoch + if hasattr(dataset, "set_epoch"): + dataset.set_epoch(epoch) + + +class Subset(TorchSubset): + def __init__(self, dataset, indices) -> None: + super(Subset, self).__init__(dataset, indices) + + self.repeat_factors = dataset.repeat_factors[indices] + assert len(indices) == len(self.repeat_factors) + + +# Adapted from Detectron2 +class RepeatFactorWrapper(Dataset): + """ + Thin wrapper around a dataset to implement repeat factor sampling. + The underlying dataset must have a repeat_factors member to indicate the per-image factor. + Set it to uniformly ones to disable repeat factor sampling + """ + + def __init__(self, dataset, seed: int = 0): + self.dataset = dataset + self.epoch_ids = None + self._seed = seed + + # Split into whole number (_int_part) and fractional (_frac_part) parts. + self._int_part = torch.trunc(dataset.repeat_factors) + self._frac_part = dataset.repeat_factors - self._int_part + + def _get_epoch_indices(self, generator): + """ + Create a list of dataset indices (with repeats) to use for one epoch. + + Args: + generator (torch.Generator): pseudo random number generator used for + stochastic rounding. + + Returns: + torch.Tensor: list of dataset indices to use in one epoch. Each index + is repeated based on its calculated repeat factor. + """ + # Since repeat factors are fractional, we use stochastic rounding so + # that the target repeat factor is achieved in expectation over the + # course of training + rands = torch.rand(len(self._frac_part), generator=generator) + rep_factors = self._int_part + (rands < self._frac_part).float() + # Construct a list of indices in which we repeat images as specified + indices = [] + for dataset_index, rep_factor in enumerate(rep_factors): + indices.extend([dataset_index] * int(rep_factor.item())) + return torch.tensor(indices, dtype=torch.int64) + + def __len__(self): + if self.epoch_ids is None: + # Here we raise an error instead of returning original len(self.dataset) avoid + # accidentally using unwrapped length. Otherwise it's error-prone since the + # length changes to `len(self.epoch_ids)`changes after set_epoch is called. + raise RuntimeError("please call set_epoch first to get wrapped length") + # return len(self.dataset) + + return len(self.epoch_ids) + + def set_epoch(self, epoch: int): + g = torch.Generator() + g.manual_seed(self._seed + epoch) + self.epoch_ids = self._get_epoch_indices(g) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + def __getitem__(self, idx): + if self.epoch_ids is None: + raise RuntimeError( + "Repeat ids haven't been computed. Did you forget to call set_epoch?" + ) + + return self.dataset[self.epoch_ids[idx]] diff --git a/third_party/sam2/training/dataset/vos_dataset.py b/third_party/sam2/training/dataset/vos_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e9d39fe184cf0d86fbf22b5385dc05988cab83 --- /dev/null +++ b/third_party/sam2/training/dataset/vos_dataset.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import random +from copy import deepcopy + +import numpy as np + +import torch +from iopath.common.file_io import g_pathmgr +from PIL import Image as PILImage +from torchvision.datasets.vision import VisionDataset + +from training.dataset.vos_raw_dataset import VOSRawDataset +from training.dataset.vos_sampler import VOSSampler +from training.dataset.vos_segment_loader import JSONSegmentLoader + +from training.utils.data_utils import Frame, Object, VideoDatapoint + +MAX_RETRIES = 100 + + +class VOSDataset(VisionDataset): + def __init__( + self, + transforms, + training: bool, + video_dataset: VOSRawDataset, + sampler: VOSSampler, + multiplier: int, + always_target=True, + target_segments_available=True, + ): + self._transforms = transforms + self.training = training + self.video_dataset = video_dataset + self.sampler = sampler + + self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32) + self.repeat_factors *= multiplier + print(f"Raw dataset length = {len(self.video_dataset)}") + + self.curr_epoch = 0 # Used in case data loader behavior changes across epochs + self.always_target = always_target + self.target_segments_available = target_segments_available + + def _get_datapoint(self, idx): + + for retry in range(MAX_RETRIES): + try: + if isinstance(idx, torch.Tensor): + idx = idx.item() + # sample a video + video, segment_loader = self.video_dataset.get_video(idx) + # sample frames and object indices to be used in a datapoint + sampled_frms_and_objs = self.sampler.sample( + video, segment_loader, epoch=self.curr_epoch + ) + break # Succesfully loaded video + except Exception as e: + if self.training: + logging.warning( + f"Loading failed (id={idx}); Retry {retry} with exception: {e}" + ) + idx = random.randrange(0, len(self.video_dataset)) + else: + # Shouldn't fail to load a val video + raise e + + datapoint = self.construct(video, sampled_frms_and_objs, segment_loader) + for transform in self._transforms: + datapoint = transform(datapoint, epoch=self.curr_epoch) + return datapoint + + def construct(self, video, sampled_frms_and_objs, segment_loader): + """ + Constructs a VideoDatapoint sample to pass to transforms + """ + sampled_frames = sampled_frms_and_objs.frames + sampled_object_ids = sampled_frms_and_objs.object_ids + + images = [] + rgb_images = load_images(sampled_frames) + # Iterate over the sampled frames and store their rgb data and object data (bbox, segment) + for frame_idx, frame in enumerate(sampled_frames): + w, h = rgb_images[frame_idx].size + images.append( + Frame( + data=rgb_images[frame_idx], + objects=[], + ) + ) + # We load the gt segments associated with the current frame + if isinstance(segment_loader, JSONSegmentLoader): + segments = segment_loader.load( + frame.frame_idx, obj_ids=sampled_object_ids + ) + else: + segments = segment_loader.load(frame.frame_idx) + for obj_id in sampled_object_ids: + # Extract the segment + if obj_id in segments: + assert ( + segments[obj_id] is not None + ), "None targets are not supported" + # segment is uint8 and remains uint8 throughout the transforms + segment = segments[obj_id].to(torch.uint8) + else: + # There is no target, we either use a zero mask target or drop this object + if not self.always_target: + continue + segment = torch.zeros(h, w, dtype=torch.uint8) + + images[frame_idx].objects.append( + Object( + object_id=obj_id, + frame_index=frame.frame_idx, + segment=segment, + ) + ) + return VideoDatapoint( + frames=images, + video_id=video.video_id, + size=(h, w), + ) + + def __getitem__(self, idx): + return self._get_datapoint(idx) + + def __len__(self): + return len(self.video_dataset) + + +def load_images(frames): + all_images = [] + cache = {} + for frame in frames: + if frame.data is None: + # Load the frame rgb data from file + path = frame.image_path + if path in cache: + all_images.append(deepcopy(all_images[cache[path]])) + continue + with g_pathmgr.open(path, "rb") as fopen: + all_images.append(PILImage.open(fopen).convert("RGB")) + cache[path] = len(all_images) - 1 + else: + # The frame rgb data has already been loaded + # Convert it to a PILImage + all_images.append(tensor_2_PIL(frame.data)) + + return all_images + + +def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image: + data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0 + data = data.astype(np.uint8) + return PILImage.fromarray(data) diff --git a/third_party/sam2/training/dataset/vos_raw_dataset.py b/third_party/sam2/training/dataset/vos_raw_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..44fe893717a3e3bd85b043baa33d349b52b4b34e --- /dev/null +++ b/third_party/sam2/training/dataset/vos_raw_dataset.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import logging +import os +from dataclasses import dataclass + +from typing import List, Optional + +import pandas as pd + +import torch + +from iopath.common.file_io import g_pathmgr + +from omegaconf.listconfig import ListConfig + +from training.dataset.vos_segment_loader import ( + JSONSegmentLoader, + MultiplePNGSegmentLoader, + PalettisedPNGSegmentLoader, + SA1BSegmentLoader, +) + + +@dataclass +class VOSFrame: + frame_idx: int + image_path: str + data: Optional[torch.Tensor] = None + is_conditioning_only: Optional[bool] = False + + +@dataclass +class VOSVideo: + video_name: str + video_id: int + frames: List[VOSFrame] + + def __len__(self): + return len(self.frames) + + +class VOSRawDataset: + def __init__(self): + pass + + def get_video(self, idx): + raise NotImplementedError() + + +class PNGRawDataset(VOSRawDataset): + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + sample_rate=1, + is_palette=True, + single_object_mode=False, + truncate_video=-1, + frames_sampling_mult=False, + ): + self.img_folder = img_folder + self.gt_folder = gt_folder + self.sample_rate = sample_rate + self.is_palette = is_palette + self.single_object_mode = single_object_mode + self.truncate_video = truncate_video + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + + # Read and process excluded files if provided + if excluded_videos_list_txt is not None: + with g_pathmgr.open(excluded_videos_list_txt, "r") as f: + excluded_files = [os.path.splitext(line.strip())[0] for line in f] + else: + excluded_files = [] + + # Check if it's not in excluded_files + self.video_names = sorted( + [video_name for video_name in subset if video_name not in excluded_files] + ) + + if self.single_object_mode: + # single object mode + self.video_names = sorted( + [ + os.path.join(video_name, obj) + for video_name in self.video_names + for obj in os.listdir(os.path.join(self.gt_folder, video_name)) + ] + ) + + if frames_sampling_mult: + video_names_mult = [] + for video_name in self.video_names: + num_frames = len(os.listdir(os.path.join(self.img_folder, video_name))) + video_names_mult.extend([video_name] * num_frames) + self.video_names = video_names_mult + + def get_video(self, idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[idx] + + if self.single_object_mode: + video_frame_root = os.path.join( + self.img_folder, os.path.dirname(video_name) + ) + else: + video_frame_root = os.path.join(self.img_folder, video_name) + + video_mask_root = os.path.join(self.gt_folder, video_name) + + if self.is_palette: + segment_loader = PalettisedPNGSegmentLoader(video_mask_root) + else: + segment_loader = MultiplePNGSegmentLoader( + video_mask_root, self.single_object_mode + ) + + all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg"))) + if self.truncate_video > 0: + all_frames = all_frames[: self.truncate_video] + frames = [] + for _, fpath in enumerate(all_frames[:: self.sample_rate]): + fid = int(os.path.basename(fpath).split(".")[0]) + frames.append(VOSFrame(fid, image_path=fpath)) + video = VOSVideo(video_name, idx, frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) + + +class SA1BRawDataset(VOSRawDataset): + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + num_frames=1, + mask_area_frac_thresh=1.1, # no filtering by default + uncertain_iou=-1, # no filtering by default + ): + self.img_folder = img_folder + self.gt_folder = gt_folder + self.num_frames = num_frames + self.mask_area_frac_thresh = mask_area_frac_thresh + self.uncertain_iou = uncertain_iou # stability score + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + subset = [ + path.split(".")[0] for path in subset if path.endswith(".jpg") + ] # remove extension + + # Read and process excluded files if provided + if excluded_videos_list_txt is not None: + with g_pathmgr.open(excluded_videos_list_txt, "r") as f: + excluded_files = [os.path.splitext(line.strip())[0] for line in f] + else: + excluded_files = [] + + # Check if it's not in excluded_files and it exists + self.video_names = [ + video_name for video_name in subset if video_name not in excluded_files + ] + + def get_video(self, idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[idx] + + video_frame_path = os.path.join(self.img_folder, video_name + ".jpg") + video_mask_path = os.path.join(self.gt_folder, video_name + ".json") + + segment_loader = SA1BSegmentLoader( + video_mask_path, + mask_area_frac_thresh=self.mask_area_frac_thresh, + video_frame_path=video_frame_path, + uncertain_iou=self.uncertain_iou, + ) + + frames = [] + for frame_idx in range(self.num_frames): + frames.append(VOSFrame(frame_idx, image_path=video_frame_path)) + video_name = video_name.split("_")[-1] # filename is sa_{int} + # video id needs to be image_id to be able to load correct annotation file during eval + video = VOSVideo(video_name, int(video_name), frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) + + +class JSONRawDataset(VOSRawDataset): + """ + Dataset where the annotation in the format of SA-V json files + """ + + def __init__( + self, + img_folder, + gt_folder, + file_list_txt=None, + excluded_videos_list_txt=None, + sample_rate=1, + rm_unannotated=True, + ann_every=1, + frames_fps=24, + ): + self.gt_folder = gt_folder + self.img_folder = img_folder + self.sample_rate = sample_rate + self.rm_unannotated = rm_unannotated + self.ann_every = ann_every + self.frames_fps = frames_fps + + # Read and process excluded files if provided + excluded_files = [] + if excluded_videos_list_txt is not None: + if isinstance(excluded_videos_list_txt, str): + excluded_videos_lists = [excluded_videos_list_txt] + elif isinstance(excluded_videos_list_txt, ListConfig): + excluded_videos_lists = list(excluded_videos_list_txt) + else: + raise NotImplementedError + + for excluded_videos_list_txt in excluded_videos_lists: + with open(excluded_videos_list_txt, "r") as f: + excluded_files.extend( + [os.path.splitext(line.strip())[0] for line in f] + ) + excluded_files = set(excluded_files) + + # Read the subset defined in file_list_txt + if file_list_txt is not None: + with g_pathmgr.open(file_list_txt, "r") as f: + subset = [os.path.splitext(line.strip())[0] for line in f] + else: + subset = os.listdir(self.img_folder) + + self.video_names = sorted( + [video_name for video_name in subset if video_name not in excluded_files] + ) + + def get_video(self, video_idx): + """ + Given a VOSVideo object, return the mask tensors. + """ + video_name = self.video_names[video_idx] + video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json") + segment_loader = JSONSegmentLoader( + video_json_path=video_json_path, + ann_every=self.ann_every, + frames_fps=self.frames_fps, + ) + + frame_ids = [ + int(os.path.splitext(frame_name)[0]) + for frame_name in sorted( + os.listdir(os.path.join(self.img_folder, video_name)) + ) + ] + + frames = [ + VOSFrame( + frame_id, + image_path=os.path.join( + self.img_folder, f"{video_name}/%05d.jpg" % (frame_id) + ), + ) + for frame_id in frame_ids[:: self.sample_rate] + ] + + if self.rm_unannotated: + # Eliminate the frames that have not been annotated + valid_frame_ids = [ + i * segment_loader.ann_every + for i, annot in enumerate(segment_loader.frame_annots) + if annot is not None and None not in annot + ] + frames = [f for f in frames if f.frame_idx in valid_frame_ids] + + video = VOSVideo(video_name, video_idx, frames) + return video, segment_loader + + def __len__(self): + return len(self.video_names) diff --git a/third_party/sam2/training/dataset/vos_sampler.py b/third_party/sam2/training/dataset/vos_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad84b759d0f66191a84017d17140d128b634ca0 --- /dev/null +++ b/third_party/sam2/training/dataset/vos_sampler.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +from dataclasses import dataclass +from typing import List + +from training.dataset.vos_segment_loader import LazySegments + +MAX_RETRIES = 1000 + + +@dataclass +class SampledFramesAndObjects: + frames: List[int] + object_ids: List[int] + + +class VOSSampler: + def __init__(self, sort_frames=True): + # frames are ordered by frame id when sort_frames is True + self.sort_frames = sort_frames + + def sample(self, video): + raise NotImplementedError() + + +class RandomUniformSampler(VOSSampler): + def __init__( + self, + num_frames, + max_num_objects, + reverse_time_prob=0.0, + ): + self.num_frames = num_frames + self.max_num_objects = max_num_objects + self.reverse_time_prob = reverse_time_prob + + def sample(self, video, segment_loader, epoch=None): + + for retry in range(MAX_RETRIES): + if len(video.frames) < self.num_frames: + raise Exception( + f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames." + ) + start = random.randrange(0, len(video.frames) - self.num_frames + 1) + frames = [video.frames[start + step] for step in range(self.num_frames)] + if random.uniform(0, 1) < self.reverse_time_prob: + # Reverse time + frames = frames[::-1] + + # Get first frame object ids + visible_object_ids = [] + loaded_segms = segment_loader.load(frames[0].frame_idx) + if isinstance(loaded_segms, LazySegments): + # LazySegments for SA1BRawDataset + visible_object_ids = list(loaded_segms.keys()) + else: + for object_id, segment in segment_loader.load( + frames[0].frame_idx + ).items(): + if segment.sum(): + visible_object_ids.append(object_id) + + # First frame needs to have at least a target to track + if len(visible_object_ids) > 0: + break + if retry >= MAX_RETRIES - 1: + raise Exception("No visible objects") + + object_ids = random.sample( + visible_object_ids, + min(len(visible_object_ids), self.max_num_objects), + ) + return SampledFramesAndObjects(frames=frames, object_ids=object_ids) + + +class EvalSampler(VOSSampler): + """ + VOS Sampler for evaluation: sampling all the frames and all the objects in a video + """ + + def __init__( + self, + ): + super().__init__() + + def sample(self, video, segment_loader, epoch=None): + """ + Sampling all the frames and all the objects + """ + if self.sort_frames: + # ordered by frame id + frames = sorted(video.frames, key=lambda x: x.frame_idx) + else: + # use the original order + frames = video.frames + object_ids = segment_loader.load(frames[0].frame_idx).keys() + if len(object_ids) == 0: + raise Exception("First frame of the video has no objects") + + return SampledFramesAndObjects(frames=frames, object_ids=object_ids) diff --git a/third_party/sam2/training/dataset/vos_segment_loader.py b/third_party/sam2/training/dataset/vos_segment_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..27e17010cc8b010e103c3ac399689d80da7cfde9 --- /dev/null +++ b/third_party/sam2/training/dataset/vos_segment_loader.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import json +import os + +import numpy as np +import pandas as pd +import torch + +from PIL import Image as PILImage + +try: + from pycocotools import mask as mask_utils +except: + pass + + +class JSONSegmentLoader: + def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None): + # Annotations in the json are provided every ann_every th frame + self.ann_every = ann_every + # Ids of the objects to consider when sampling this video + self.valid_obj_ids = valid_obj_ids + with open(video_json_path, "r") as f: + data = json.load(f) + if isinstance(data, list): + self.frame_annots = data + elif isinstance(data, dict): + masklet_field_name = "masklet" if "masklet" in data else "masks" + self.frame_annots = data[masklet_field_name] + if "fps" in data: + if isinstance(data["fps"], list): + annotations_fps = int(data["fps"][0]) + else: + annotations_fps = int(data["fps"]) + assert frames_fps % annotations_fps == 0 + self.ann_every = frames_fps // annotations_fps + else: + raise NotImplementedError + + def load(self, frame_id, obj_ids=None): + assert frame_id % self.ann_every == 0 + rle_mask = self.frame_annots[frame_id // self.ann_every] + + valid_objs_ids = set(range(len(rle_mask))) + if self.valid_obj_ids is not None: + # Remove the masklets that have been filtered out for this video + valid_objs_ids &= set(self.valid_obj_ids) + if obj_ids is not None: + # Only keep the objects that have been sampled + valid_objs_ids &= set(obj_ids) + valid_objs_ids = sorted(list(valid_objs_ids)) + + # Construct rle_masks_filtered that only contains the rle masks we are interested in + id_2_idx = {} + rle_mask_filtered = [] + for obj_id in valid_objs_ids: + if rle_mask[obj_id] is not None: + id_2_idx[obj_id] = len(rle_mask_filtered) + rle_mask_filtered.append(rle_mask[obj_id]) + else: + id_2_idx[obj_id] = None + + # Decode the masks + raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute( + 2, 0, 1 + ) # (num_obj, h, w) + segments = {} + for obj_id in valid_objs_ids: + if id_2_idx[obj_id] is None: + segments[obj_id] = None + else: + idx = id_2_idx[obj_id] + segments[obj_id] = raw_segments[idx] + return segments + + def get_valid_obj_frames_ids(self, num_frames_min=None): + # For each object, find all the frames with a valid (not None) mask + num_objects = len(self.frame_annots[0]) + + # The result dict associates each obj_id with the id of its valid frames + res = {obj_id: [] for obj_id in range(num_objects)} + + for annot_idx, annot in enumerate(self.frame_annots): + for obj_id in range(num_objects): + if annot[obj_id] is not None: + res[obj_id].append(int(annot_idx * self.ann_every)) + + if num_frames_min is not None: + # Remove masklets that have less than num_frames_min valid masks + for obj_id, valid_frames in list(res.items()): + if len(valid_frames) < num_frames_min: + res.pop(obj_id) + + return res + + +class PalettisedPNGSegmentLoader: + def __init__(self, video_png_root): + """ + SegmentLoader for datasets with masks stored as palettised PNGs. + video_png_root: the folder contains all the masks stored in png + """ + self.video_png_root = video_png_root + # build a mapping from frame id to their PNG mask path + # note that in some datasets, the PNG paths could have more + # than 5 digits, e.g. "00000000.png" instead of "00000.png" + png_filenames = os.listdir(self.video_png_root) + self.frame_id_to_png_filename = {} + for filename in png_filenames: + frame_id, _ = os.path.splitext(filename) + self.frame_id_to_png_filename[int(frame_id)] = filename + + def load(self, frame_id): + """ + load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + # check the path + mask_path = os.path.join( + self.video_png_root, self.frame_id_to_png_filename[frame_id] + ) + + # load the mask + masks = PILImage.open(mask_path).convert("P") + masks = np.array(masks) + + object_id = pd.unique(masks.flatten()) + object_id = object_id[object_id != 0] # remove background (0) + + # convert into N binary segmentation masks + binary_segments = {} + for i in object_id: + bs = masks == i + binary_segments[i] = torch.from_numpy(bs) + + return binary_segments + + def __len__(self): + return + + +class MultiplePNGSegmentLoader: + def __init__(self, video_png_root, single_object_mode=False): + """ + video_png_root: the folder contains all the masks stored in png + single_object_mode: whether to load only a single object at a time + """ + self.video_png_root = video_png_root + self.single_object_mode = single_object_mode + # read a mask to know the resolution of the video + if self.single_object_mode: + tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0] + else: + tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0] + tmp_mask = np.array(PILImage.open(tmp_mask_path)) + self.H = tmp_mask.shape[0] + self.W = tmp_mask.shape[1] + if self.single_object_mode: + self.obj_id = ( + int(video_png_root.split("/")[-1]) + 1 + ) # offset by 1 as bg is 0 + else: + self.obj_id = None + + def load(self, frame_id): + if self.single_object_mode: + return self._load_single_png(frame_id) + else: + return self._load_multiple_pngs(frame_id) + + def _load_single_png(self, frame_id): + """ + load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png") + binary_segments = {} + + if os.path.exists(mask_path): + mask = np.array(PILImage.open(mask_path)) + else: + # if png doesn't exist, empty mask + mask = np.zeros((self.H, self.W), dtype=bool) + binary_segments[self.obj_id] = torch.from_numpy(mask > 0) + return binary_segments + + def _load_multiple_pngs(self, frame_id): + """ + load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png') + Args: + frame_id: int, define the mask path + Return: + binary_segments: dict + """ + # get the path + all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*"))) + num_objects = len(all_objects) + assert num_objects > 0 + + # load the masks + binary_segments = {} + for obj_folder in all_objects: + # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder + obj_id = int(obj_folder.split("/")[-1]) + obj_id = obj_id + 1 # offset 1 as bg is 0 + mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png") + if os.path.exists(mask_path): + mask = np.array(PILImage.open(mask_path)) + else: + mask = np.zeros((self.H, self.W), dtype=bool) + binary_segments[obj_id] = torch.from_numpy(mask > 0) + + return binary_segments + + def __len__(self): + return + + +class LazySegments: + """ + Only decodes segments that are actually used. + """ + + def __init__(self): + self.segments = {} + self.cache = {} + + def __setitem__(self, key, item): + self.segments[key] = item + + def __getitem__(self, key): + if key in self.cache: + return self.cache[key] + rle = self.segments[key] + mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0] + self.cache[key] = mask + return mask + + def __contains__(self, key): + return key in self.segments + + def __len__(self): + return len(self.segments) + + def keys(self): + return self.segments.keys() + + +class SA1BSegmentLoader: + def __init__( + self, + video_mask_path, + mask_area_frac_thresh=1.1, + video_frame_path=None, + uncertain_iou=-1, + ): + with open(video_mask_path, "r") as f: + self.frame_annots = json.load(f) + + if mask_area_frac_thresh <= 1.0: + # Lazily read frame + orig_w, orig_h = PILImage.open(video_frame_path).size + area = orig_w * orig_h + + self.frame_annots = self.frame_annots["annotations"] + + rle_masks = [] + for frame_annot in self.frame_annots: + if not frame_annot["area"] > 0: + continue + if ("uncertain_iou" in frame_annot) and ( + frame_annot["uncertain_iou"] < uncertain_iou + ): + # uncertain_iou is stability score + continue + if ( + mask_area_frac_thresh <= 1.0 + and (frame_annot["area"] / area) >= mask_area_frac_thresh + ): + continue + rle_masks.append(frame_annot["segmentation"]) + + self.segments = LazySegments() + for i, rle in enumerate(rle_masks): + self.segments[i] = rle + + def load(self, frame_idx): + return self.segments diff --git a/third_party/sam2/training/loss_fns.py b/third_party/sam2/training/loss_fns.py new file mode 100644 index 0000000000000000000000000000000000000000..d281b1a9c059771ee0ae3a4d4426f1e445178110 --- /dev/null +++ b/third_party/sam2/training/loss_fns.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from typing import Dict, List + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F + +from training.trainer import CORE_LOSS_KEY + +from training.utils.distributed import get_world_size, is_dist_avail_and_initialized + + +def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + num_objects: Number of objects in the batch + loss_on_multimask: True if multimask prediction is enabled + Returns: + Dice loss tensor + """ + inputs = inputs.sigmoid() + if loss_on_multimask: + # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks + assert inputs.dim() == 4 and targets.dim() == 4 + # flatten spatial dimension while keeping multimask channel dimension + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +def sigmoid_focal_loss( + inputs, + targets, + num_objects, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + num_objects: Number of objects in the batch + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + loss_on_multimask: True if multimask prediction is enabled + Returns: + focal loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if loss_on_multimask: + # loss is [N, M, H, W] where M corresponds to multiple predicted masks + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_objects # average over spatial dims + return loss.mean(1).sum() / num_objects + + +def iou_loss( + inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False +): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + pred_ious: A float tensor containing the predicted IoUs scores per mask + num_objects: Number of objects in the batch + loss_on_multimask: True if multimask prediction is enabled + use_l1_loss: Whether to use L1 loss is used instead of MSE loss + Returns: + IoU loss tensor + """ + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +class MultiStepMultiMasksAndIous(nn.Module): + def __init__( + self, + weight_dict, + focal_alpha=0.25, + focal_gamma=2, + supervise_all_iou=False, + iou_use_l1_loss=False, + pred_obj_scores=False, + focal_gamma_obj_score=0.0, + focal_alpha_obj_score=-1, + ): + """ + This class computes the multi-step multi-mask and IoU losses. + Args: + weight_dict: dict containing weights for focal, dice, iou losses + focal_alpha: alpha for sigmoid focal loss + focal_gamma: gamma for sigmoid focal loss + supervise_all_iou: if True, back-prop iou losses for all predicted masks + iou_use_l1_loss: use L1 loss instead of MSE loss for iou + pred_obj_scores: if True, compute loss for object scores + focal_gamma_obj_score: gamma for sigmoid focal loss on object scores + focal_alpha_obj_score: alpha for sigmoid focal loss on object scores + """ + + super().__init__() + self.weight_dict = weight_dict + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + assert "loss_mask" in self.weight_dict + assert "loss_dice" in self.weight_dict + assert "loss_iou" in self.weight_dict + if "loss_class" not in self.weight_dict: + self.weight_dict["loss_class"] = 0.0 + + self.focal_alpha_obj_score = focal_alpha_obj_score + self.focal_gamma_obj_score = focal_gamma_obj_score + self.supervise_all_iou = supervise_all_iou + self.iou_use_l1_loss = iou_use_l1_loss + self.pred_obj_scores = pred_obj_scores + + def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): + assert len(outs_batch) == len(targets_batch) + num_objects = torch.tensor( + (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float + ) # Number of objects is fixed within a batch + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_objects) + num_objects = torch.clamp(num_objects / get_world_size(), min=1).item() + + losses = defaultdict(int) + for outs, targets in zip(outs_batch, targets_batch): + cur_losses = self._forward(outs, targets, num_objects) + for k, v in cur_losses.items(): + losses[k] += v + + return losses + + def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + """ + Compute the losses related to the masks: the focal loss and the dice loss. + and also the MAE or MSE loss between predicted IoUs and actual IoUs. + + Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors + of shape [N, M, H, W], where M could be 1 or larger, corresponding to + one or multiple predicted masks from a click. + + We back-propagate focal, dice losses only on the prediction channel + with the lowest focal+dice loss between predicted mask and ground-truth. + If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks. + """ + + target_masks = targets.unsqueeze(1).float() + assert target_masks.dim() == 4 # [N, 1, H, W] + src_masks_list = outputs["multistep_pred_multimasks_high_res"] + ious_list = outputs["multistep_pred_ious"] + object_score_logits_list = outputs["multistep_object_score_logits"] + + assert len(src_masks_list) == len(ious_list) + assert len(object_score_logits_list) == len(ious_list) + + # accumulate the loss over prediction steps + losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} + for src_masks, ious, object_score_logits in zip( + src_masks_list, ious_list, object_score_logits_list + ): + self._update_losses( + losses, src_masks, target_masks, ious, num_objects, object_score_logits + ) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def _update_losses( + self, losses, src_masks, target_masks, ious, num_objects, object_score_logits + ): + target_masks = target_masks.expand_as(src_masks) + # get focal, dice and iou loss on all output masks in a prediction step + loss_multimask = sigmoid_focal_loss( + src_masks, + target_masks, + num_objects, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + loss_on_multimask=True, + ) + loss_multidice = dice_loss( + src_masks, target_masks, num_objects, loss_on_multimask=True + ) + if not self.pred_obj_scores: + loss_class = torch.tensor( + 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device + ) + target_obj = torch.ones( + loss_multimask.shape[0], + 1, + dtype=loss_multimask.dtype, + device=loss_multimask.device, + ) + else: + target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ + ..., None + ].float() + loss_class = sigmoid_focal_loss( + object_score_logits, + target_obj, + num_objects, + alpha=self.focal_alpha_obj_score, + gamma=self.focal_gamma_obj_score, + ) + + loss_multiiou = iou_loss( + src_masks, + target_masks, + ious, + num_objects, + loss_on_multimask=True, + use_l1_loss=self.iou_use_l1_loss, + ) + assert loss_multimask.dim() == 2 + assert loss_multidice.dim() == 2 + assert loss_multiiou.dim() == 2 + if loss_multimask.size(1) > 1: + # take the mask indices with the smallest focal + dice loss for back propagation + loss_combo = ( + loss_multimask * self.weight_dict["loss_mask"] + + loss_multidice * self.weight_dict["loss_dice"] + ) + best_loss_inds = torch.argmin(loss_combo, dim=-1) + batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) + loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) + loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) + # calculate the iou prediction and slot losses only in the index + # with the minimum loss for each mask (to be consistent w/ SAM) + if self.supervise_all_iou: + loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + else: + loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + else: + loss_mask = loss_multimask + loss_dice = loss_multidice + loss_iou = loss_multiiou + + # backprop focal, dice and iou loss only if obj present + loss_mask = loss_mask * target_obj + loss_dice = loss_dice * target_obj + loss_iou = loss_iou * target_obj + + # sum over batch dimension (note that the losses are already divided by num_objects) + losses["loss_mask"] += loss_mask.sum() + losses["loss_dice"] += loss_dice.sum() + losses["loss_iou"] += loss_iou.sum() + losses["loss_class"] += loss_class + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + if weight != 0: + reduced_loss += losses[loss_key] * weight + + return reduced_loss diff --git a/third_party/sam2/training/model/__init__.py b/third_party/sam2/training/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/training/model/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/training/model/sam2.py b/third_party/sam2/training/model/sam2.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7567c4dc99942d48e5890529ba9e3ca265e02d --- /dev/null +++ b/third_party/sam2/training/model/sam2.py @@ -0,0 +1,541 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +import torch.distributed +from sam2.modeling.sam2_base import SAM2Base +from sam2.modeling.sam2_utils import ( + get_1d_sine_pe, + get_next_point, + sample_box_points, + select_closest_cond_frames, +) + +from sam2.utils.misc import concat_points + +from training.utils.data_utils import BatchedVideoDatapoint + + +class SAM2Train(SAM2Base): + def __init__( + self, + image_encoder, + memory_attention=None, + memory_encoder=None, + prob_to_use_pt_input_for_train=0.0, + prob_to_use_pt_input_for_eval=0.0, + prob_to_use_box_input_for_train=0.0, + prob_to_use_box_input_for_eval=0.0, + # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames + num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame + num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame + rand_frames_to_correct_for_train=False, + rand_frames_to_correct_for_eval=False, + # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame) + # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames + # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames + # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`; + # these are initial conditioning frames because as we track the video, more conditioning frames might be added + # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True` + num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame + num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame + rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader) + rand_init_cond_frames_for_eval=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # how many additional correction points to sample (on each frame selected to be corrected) + # note that the first frame receives an initial input click (in addition to any correction clicks) + num_correction_pt_per_frame=7, + # method for point sampling during evaluation + # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary) + # default to "center" to be consistent with evaluation in the SAM paper + pt_sampling_for_eval="center", + # During training, we optionally allow sampling the correction points from GT regions + # instead of the prediction error regions with a small probability. This might allow the + # model to overfit less to the error regions in training datasets + prob_to_sample_from_gt_for_train=0.0, + use_act_ckpt_iterative_pt_sampling=False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval=False, + freeze_image_encoder=False, + **kwargs, + ): + super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs) + self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + + # Point sampler and conditioning frames + self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train + self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train + self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval + self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval + if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: + logging.info( + f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}" + ) + assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train + assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval + + self.num_frames_to_correct_for_train = num_frames_to_correct_for_train + self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval + self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train + self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval + # Initial multi-conditioning frames + self.num_init_cond_frames_for_train = num_init_cond_frames_for_train + self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval + self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train + self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.num_correction_pt_per_frame = num_correction_pt_per_frame + self.pt_sampling_for_eval = pt_sampling_for_eval + self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train + # A random number generator with a fixed initial seed across GPUs + self.rng = np.random.default_rng(seed=42) + + if freeze_image_encoder: + for p in self.image_encoder.parameters(): + p.requires_grad = False + + def forward(self, input: BatchedVideoDatapoint): + if self.training or not self.forward_backbone_per_frame_for_eval: + # precompute image features on all frames before tracking + backbone_out = self.forward_image(input.flat_img_batch) + else: + # defer image feature computation on a frame until it's being tracked + backbone_out = {"backbone_fpn": None, "vision_pos_enc": None} + backbone_out = self.prepare_prompt_inputs(backbone_out, input) + previous_stages_out = self.forward_tracking(backbone_out, input) + + return previous_stages_out + + def _prepare_backbone_features_per_frame(self, img_batch, img_ids): + """Compute the image backbone features on the fly for the given img_ids.""" + # Only forward backbone on unique image ids to avoid repetitive computation + # (if `img_ids` has only one element, it's already unique so we skip this step). + if img_ids.numel() > 1: + unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True) + else: + unique_img_ids, inv_ids = img_ids, None + + # Compute the image features on those unique image ids + image = img_batch[unique_img_ids] + backbone_out = self.forward_image(image) + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + # Inverse-map image features for `unique_img_ids` to the final image features + # for the original input `img_ids`. + if inv_ids is not None: + image = image[inv_ids] + vision_feats = [x[:, inv_ids] for x in vision_feats] + vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds] + + return image, vision_feats, vision_pos_embeds, feat_sizes + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + # Load the ground-truth masks on all frames (so that we can later + # sample correction points from them) + # gt_masks_per_frame = { + # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im] + # for stage_id, targets in enumerate(input.find_targets) + # } + gt_masks_per_frame = { + stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im] + for stage_id, masks in enumerate(input.masks) + } + # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + num_frames = input.num_frames + backbone_out["num_frames"] = num_frames + + # Randomly decide whether to use point inputs or mask inputs + if self.training: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_train + prob_to_use_box_input = self.prob_to_use_box_input_for_train + num_frames_to_correct = self.num_frames_to_correct_for_train + rand_frames_to_correct = self.rand_frames_to_correct_for_train + num_init_cond_frames = self.num_init_cond_frames_for_train + rand_init_cond_frames = self.rand_init_cond_frames_for_train + else: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval + prob_to_use_box_input = self.prob_to_use_box_input_for_eval + num_frames_to_correct = self.num_frames_to_correct_for_eval + rand_frames_to_correct = self.rand_frames_to_correct_for_eval + num_init_cond_frames = self.num_init_cond_frames_for_eval + rand_init_cond_frames = self.rand_init_cond_frames_for_eval + if num_frames == 1: + # here we handle a special case for mixing video + SAM on image training, + # where we force using point input for the SAM task on static images + prob_to_use_pt_input = 1.0 + num_frames_to_correct = 1 + num_init_cond_frames = 1 + assert num_init_cond_frames >= 1 + # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0) + use_pt_input = self.rng.random() < prob_to_use_pt_input + if rand_init_cond_frames and num_init_cond_frames > 1: + # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames + num_init_cond_frames = self.rng.integers( + 1, num_init_cond_frames, endpoint=True + ) + if ( + use_pt_input + and rand_frames_to_correct + and num_frames_to_correct > num_init_cond_frames + ): + # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample + # correction clicks (only for the case of point input) + num_frames_to_correct = self.rng.integers( + num_init_cond_frames, num_frames_to_correct, endpoint=True + ) + backbone_out["use_pt_input"] = use_pt_input + + # Sample initial conditioning frames + if num_init_cond_frames == 1: + init_cond_frames = [start_frame_idx] # starting frame + else: + # starting frame + randomly selected remaining frames (without replacement) + init_cond_frames = [start_frame_idx] + self.rng.choice( + range(start_frame_idx + 1, num_frames), + num_init_cond_frames - 1, + replace=False, + ).tolist() + backbone_out["init_cond_frames"] = init_cond_frames + backbone_out["frames_not_in_init_cond"] = [ + t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames + ] + # Prepare mask or point inputs on initial conditioning frames + backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: } + backbone_out["point_inputs_per_frame"] = {} # {frame_idx: } + for t in init_cond_frames: + if not use_pt_input: + backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] + else: + # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input + use_box_input = self.rng.random() < prob_to_use_box_input + if use_box_input: + points, labels = sample_box_points( + gt_masks_per_frame[t], + ) + else: + # (here we only sample **one initial point** on initial conditioning frames from the + # ground-truth mask; we may sample more correction points on the fly) + points, labels = get_next_point( + gt_masks=gt_masks_per_frame[t], + pred_masks=None, + method=( + "uniform" if self.training else self.pt_sampling_for_eval + ), + ) + + point_inputs = {"point_coords": points, "point_labels": labels} + backbone_out["point_inputs_per_frame"][t] = point_inputs + + # Sample frames where we will add correction clicks on the fly + # based on the error between prediction and ground-truth masks + if not use_pt_input: + # no correction points will be sampled when using mask inputs + frames_to_add_correction_pt = [] + elif num_frames_to_correct == num_init_cond_frames: + frames_to_add_correction_pt = init_cond_frames + else: + assert num_frames_to_correct > num_init_cond_frames + # initial cond frame + randomly selected remaining frames (without replacement) + extra_num = num_frames_to_correct - num_init_cond_frames + frames_to_add_correction_pt = ( + init_cond_frames + + self.rng.choice( + backbone_out["frames_not_in_init_cond"], extra_num, replace=False + ).tolist() + ) + backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt + + return backbone_out + + def forward_tracking( + self, backbone_out, input: BatchedVideoDatapoint, return_dict=False + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = backbone_out["backbone_fpn"] is not None + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + for stage_id in processing_order: + # Get the image features for the current frames + # img_ids = input.find_inputs[stage_id].img_ids + img_ids = input.flat_obj_to_img_idx[stage_id] + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_vision_feats = [x[:, img_ids] for x in vision_feats] + current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds] + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + ( + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features_per_frame( + input.flat_img_batch, img_ids + ) + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks. + prev_sam_mask_logits=None, # The previously predicted SAM mask logits. + frames_to_add_correction_pt=None, + gt_masks=None, + ): + if frames_to_add_correction_pt is None: + frames_to_add_correction_pt = [] + current_out, sam_outputs, high_res_features, pix_feat = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + point_inputs, final_sam_outputs = self._iter_correct_pt_sampling( + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = final_sam_outputs + + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + return current_out + + def _iter_correct_pt_sampling( + self, + is_init_cond_frame, + point_inputs, + gt_masks, + high_res_features, + pix_feat_with_mem, + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + object_score_logits, + current_out, + ): + + assert gt_masks is not None + all_pred_masks = [low_res_masks] + all_pred_high_res_masks = [high_res_masks] + all_pred_multimasks = [low_res_multimasks] + all_pred_high_res_multimasks = [high_res_multimasks] + all_pred_ious = [ious] + all_point_inputs = [point_inputs] + all_object_score_logits = [object_score_logits] + for _ in range(self.num_correction_pt_per_frame): + # sample a new point from the error between prediction and ground-truth + # (with a small probability, directly sample from GT masks instead of errors) + if self.training and self.prob_to_sample_from_gt_for_train > 0: + sample_from_gt = ( + self.rng.random() < self.prob_to_sample_from_gt_for_train + ) + else: + sample_from_gt = False + # if `pred_for_new_pt` is None, only GT masks will be used for point sampling + pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) + new_points, new_labels = get_next_point( + gt_masks=gt_masks, + pred_masks=pred_for_new_pt, + method="uniform" if self.training else self.pt_sampling_for_eval, + ) + point_inputs = concat_points(point_inputs, new_points, new_labels) + # Feed the mask logits of the previous SAM outputs in the next SAM decoder step. + # For tracking, this means that when the user adds a correction click, we also feed + # the tracking output mask logits along with the click as input to the SAM decoder. + mask_inputs = low_res_masks + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + if self.use_act_ckpt_iterative_pt_sampling and not multimask_output: + sam_outputs = torch.utils.checkpoint.checkpoint( + self._forward_sam_heads, + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + use_reentrant=False, + ) + else: + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + _, + object_score_logits, + ) = sam_outputs + all_pred_masks.append(low_res_masks) + all_pred_high_res_masks.append(high_res_masks) + all_pred_multimasks.append(low_res_multimasks) + all_pred_high_res_multimasks.append(high_res_multimasks) + all_pred_ious.append(ious) + all_point_inputs.append(point_inputs) + all_object_score_logits.append(object_score_logits) + + # Concatenate the masks along channel (to compute losses on all of them, + # using `MultiStepIteractiveMasks`) + current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) + current_out["multistep_pred_masks_high_res"] = torch.cat( + all_pred_high_res_masks, dim=1 + ) + current_out["multistep_pred_multimasks"] = all_pred_multimasks + current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks + current_out["multistep_pred_ious"] = all_pred_ious + current_out["multistep_point_inputs"] = all_point_inputs + current_out["multistep_object_score_logits"] = all_object_score_logits + + return point_inputs, sam_outputs diff --git a/third_party/sam2/training/optimizer.py b/third_party/sam2/training/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ae159663f6efc2dac4f5ffa3b1c91b97a78dec76 --- /dev/null +++ b/third_party/sam2/training/optimizer.py @@ -0,0 +1,502 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import fnmatch +import inspect +import itertools +import logging +import types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) + +import hydra + +import torch +import torch.nn as nn +from omegaconf import DictConfig +from torch import Tensor + + +class Optimizer: + def __init__(self, optimizer, schedulers=None) -> None: + self.optimizer = optimizer + self.schedulers = schedulers + self._validate_optimizer_schedulers() + self.step_schedulers(0.0, 0) + + def _validate_optimizer_schedulers(self): + if self.schedulers is None: + return + for _, set_of_schedulers in enumerate(self.schedulers): + for option, _ in set_of_schedulers.items(): + assert option in self.optimizer.defaults, ( + "Optimizer option " + f"{option} not found in {self.optimizer}. Valid options are " + f"{self.optimizer.defaults.keys()}" + ) + + def step_schedulers(self, where: float, step: int) -> None: + if self.schedulers is None: + return + for i, param_group in enumerate(self.optimizer.param_groups): + for option, scheduler in self.schedulers[i].items(): + if "step" in inspect.signature(scheduler.__call__).parameters: + new_value = scheduler(step=step, where=where) + elif ( + hasattr(scheduler, "scheduler") + and "step" + in inspect.signature(scheduler.scheduler.__call__).parameters + ): + # To handle ValueScaler wrappers + new_value = scheduler(step=step, where=where) + else: + new_value = scheduler(where) + param_group[option] = new_value + + def step(self, where, step, closure=None): + self.step_schedulers(where, step) + return self.optimizer.step(closure) + + def zero_grad(self, *args, **kwargs): + return self.optimizer.zero_grad(*args, **kwargs) + + +def set_default_parameters( + scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str] +) -> None: + """Set up the "default" scheduler with the right parameters. + + Args: + scheduler_cgfs: A list of scheduler configs, where each scheduler also + specifies which parameters it applies to, based on the names of parameters + or the class of the modules. At most one scheduler is allowed to skip this + specification, which is used as a "default" specification for any remaining + parameters. + all_parameter_names: Names of all the parameters to consider. + """ + constraints = [ + scheduler_cfg.parameter_names + for scheduler_cfg in scheduler_cfgs + if scheduler_cfg.parameter_names is not None + ] + if len(constraints) == 0: + default_params = set(all_parameter_names) + else: + default_params = all_parameter_names - set.union(*constraints) + default_count = 0 + for scheduler_cfg in scheduler_cfgs: + if scheduler_cfg.parameter_names is None: + scheduler_cfg.parameter_names = default_params + default_count += 1 + assert default_count <= 1, "Only one scheduler per option can be default" + if default_count == 0: + # No default scheduler specified, add a default, but without any scheduler + # for that option + scheduler_cfgs.append({"parameter_names": default_params}) + + +def name_constraints_to_parameters( + param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor] +) -> List[torch.nn.Parameter]: + """Return parameters which match the intersection of parameter constraints. + + Note that this returns the parameters themselves, not their names. + + Args: + param_constraints: A list, with each element being a set of allowed parameters. + named_parameters: Mapping from a parameter name to the parameter itself. + + Returns: + A list containing the parameters which overlap with _each_ constraint set from + param_constraints. + """ + matching_names = set.intersection(*param_constraints) + return [value for name, value in named_parameters.items() if name in matching_names] + + +def map_scheduler_cfgs_to_param_groups( + all_scheduler_cfgs: Iterable[List[Dict]], + named_parameters: Dict[str, Tensor], +) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]: + """Produce parameter groups corresponding to all the scheduler configs. + + Takes all the scheduler configs, each of which applies to a specific optimizer + option (like "lr" or "weight_decay") and has a set of parameter names which it + applies to, and produces a final set of param groups where each param group + covers all the options which apply to a particular set of parameters. + + Args: + all_scheduler_cfgs: All the scheduler configs covering every option. + named_parameters: Mapping from a parameter name to the parameter itself. + Returns: + Tuple of lists of schedulers and param_groups, where schedulers[i] + applies to param_groups[i]. + """ + + scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs) + schedulers = [] + param_groups = [] + for scheduler_cfgs in scheduler_cfgs_per_param_group: + param_constraints = [ + scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs + ] + matching_parameters = name_constraints_to_parameters( + param_constraints, named_parameters + ) + if len(matching_parameters) == 0: # If no overlap of parameters, skip + continue + schedulers_for_group = { + scheduler_cfg["option"]: scheduler_cfg["scheduler"] + for scheduler_cfg in scheduler_cfgs + if "option" in scheduler_cfg + } + schedulers.append(schedulers_for_group) + param_groups.append({"params": matching_parameters}) + return schedulers, param_groups + + +def validate_param_group_params(param_groups: List[Dict], model: nn.Module): + """Check that the param groups are non-overlapping and cover all the parameters. + + Args: + param_groups: List of all param groups + model: Model to validate against. The check ensures that all the model + parameters are part of param_groups + """ + for pg in param_groups: + # no param should be repeated within a group + assert len(pg["params"]) == len(set(pg["params"])) + parameters = [set(param_group["params"]) for param_group in param_groups] + model_parameters = {parameter for _, parameter in model.named_parameters()} + for p1, p2 in itertools.permutations(parameters, 2): + assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint" + assert set.union(*parameters) == model_parameters, ( + "Scheduler generated param_groups must include all parameters of the model." + f" Found {len(set.union(*parameters))} params whereas model has" + f" {len(model_parameters)} params" + ) + + +def unix_module_cls_pattern_to_parameter_names( + filter_module_cls_names: List[str], + module_cls_to_param_names: Dict[Type, str], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in filter_module_cls_names. + + Args: + filter_module_cls_names: A list of filter strings containing class names, like + ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"] + module_cls_to_param_names: Mapping from module classes to the parameter names + they contain. See `get_module_cls_to_param_names`. + """ + if filter_module_cls_names is None: + return set() + allowed_parameter_names = [] + for module_cls_name in filter_module_cls_names: + module_cls = hydra.utils.get_class(module_cls_name) + if module_cls not in module_cls_to_param_names: + raise AssertionError( + f"module_cls_name {module_cls_name} does not " + "match any classes in the model" + ) + matching_parameters = module_cls_to_param_names[module_cls] + assert ( + len(matching_parameters) > 0 + ), f"module_cls_name {module_cls_name} does not contain any parameters in the model" + logging.info( + f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} " + ) + allowed_parameter_names.append(matching_parameters) + return set.union(*allowed_parameter_names) + + +def unix_param_pattern_to_parameter_names( + filter_param_names: Optional[List[str]], + parameter_names: Dict[str, torch.Tensor], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in filter_param_names. + + Args: + filter_param_names: A list of unix-style filter strings with optional + wildcards, like ["block.2.*", "block.2.linear.weight"] + module_cls_to_param_names: Mapping from module classes to the parameter names + they contain. See `get_module_cls_to_param_names`. + """ + + if filter_param_names is None: + return set() + allowed_parameter_names = [] + for param_name in filter_param_names: + matching_parameters = set(fnmatch.filter(parameter_names, param_name)) + assert ( + len(matching_parameters) >= 1 + ), f"param_name {param_name} does not match any parameters in the model" + logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}") + allowed_parameter_names.append(matching_parameters) + return set.union(*allowed_parameter_names) + + +def _unix_pattern_to_parameter_names( + scheduler_cfg: DictConfig, + parameter_names: Set[str], + module_cls_to_param_names: Dict[Type, str], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in scheduler_cfg. + + Args: + scheduler_cfg: The config for the scheduler + parameter_names: The set of all parameter names which will be filtered + """ + if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg: + return None + return unix_param_pattern_to_parameter_names( + scheduler_cfg.get("param_names"), parameter_names + ).union( + unix_module_cls_pattern_to_parameter_names( + scheduler_cfg.get("module_cls_names"), module_cls_to_param_names + ) + ) + + +def get_module_cls_to_param_names( + model: nn.Module, param_allowlist: Set[str] = None +) -> Dict[Type, str]: + """Produce a mapping from all the modules classes to the names of parames they own. + + Only counts a parameter as part of the immediate parent module, i.e. recursive + parents do not count. + + Args: + model: Model to iterate over + param_allowlist: If specified, only these param names will be processed + """ + + module_cls_to_params = {} + for module_name, module in model.named_modules(): + module_cls = type(module) + module_cls_to_params.setdefault(module_cls, set()) + for param_name, _ in module.named_parameters(recurse=False): + full_param_name = get_full_parameter_name(module_name, param_name) + if param_allowlist is None or full_param_name in param_allowlist: + module_cls_to_params[module_cls].add(full_param_name) + return module_cls_to_params + + +def construct_optimizer( + model: torch.nn.Module, + optimizer_conf: Any, + options_conf: Mapping[str, List] = None, + param_group_modifiers_conf: List[Callable] = None, + param_allowlist: Optional[Set[str]] = None, + validate_param_groups=True, +) -> Optimizer: + """ + Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer + with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay + Batchnorm and/or no-update 1-D parameters support, based on the config. + + Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling + (LARS): https://arxiv.org/abs/1708.03888 + + Args: + model: model to perform stochastic gradient descent + optimization or ADAM optimization. + optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or + ADAM, still missing the params argument which this function provides to + produce the final optimizer + param_group_modifiers_conf: Optional user specified functions which can modify + the final scheduler configs before the optimizer's param groups are built + param_allowlist: The parameters to optimize. Parameters which are not part of + this allowlist will be skipped. + validate_param_groups: If enabled, valides that the produced param_groups don't + overlap and cover all the model parameters. + """ + if param_allowlist is None: + param_allowlist = {name for name, _ in model.named_parameters()} + + named_parameters = { + name: param + for name, param in model.named_parameters() + if name in param_allowlist + } + + if not options_conf: + optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values()) + return Optimizer(optimizer) + + all_parameter_names = { + name for name, _ in model.named_parameters() if name in param_allowlist + } + module_cls_to_all_param_names = get_module_cls_to_param_names( + model, param_allowlist + ) + + scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf) + all_scheduler_cfgs = [] + for option, scheduler_cfgs in scheduler_cfgs_per_option.items(): + for config in scheduler_cfgs: + config.option = option + config.parameter_names = _unix_pattern_to_parameter_names( + config, all_parameter_names, module_cls_to_all_param_names + ) + set_default_parameters(scheduler_cfgs, all_parameter_names) + all_scheduler_cfgs.append(scheduler_cfgs) + + if param_group_modifiers_conf: + for custom_param_modifier in param_group_modifiers_conf: + custom_param_modifier = hydra.utils.instantiate(custom_param_modifier) + all_scheduler_cfgs = custom_param_modifier( + scheduler_cfgs=all_scheduler_cfgs, model=model + ) + schedulers, param_groups = map_scheduler_cfgs_to_param_groups( + all_scheduler_cfgs, named_parameters + ) + if validate_param_groups: + validate_param_group_params(param_groups, model) + optimizer = hydra.utils.instantiate(optimizer_conf, param_groups) + return Optimizer(optimizer, schedulers) + + +def get_full_parameter_name(module_name, param_name): + if module_name == "": + return param_name + return f"{module_name}.{param_name}" + + +class GradientClipper: + """ + Gradient clipping utils that works for DDP + """ + + def __init__(self, max_norm: float = 1.0, norm_type: int = 2): + assert isinstance(max_norm, (int, float)) or max_norm is None + self.max_norm = max_norm if max_norm is None else float(max_norm) + self.norm_type = norm_type + + def __call__(self, model: nn.Module): + if self.max_norm is None: + return # no-op + + nn.utils.clip_grad_norm_( + model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type + ) + + +class ValueScaler: + def __init__(self, scheduler, mult_val: float): + self.scheduler = scheduler + self.mult_val = mult_val + + def __call__(self, *args, **kwargs): + val = self.scheduler(*args, **kwargs) + return val * self.mult_val + + +def rgetattr(obj, rattrs: str = None): + """ + Like getattr(), but supports dotted notation for nested objects. + rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2 + """ + if rattrs is None: + return obj + attrs = rattrs.split(".") + for attr in attrs: + obj = getattr(obj, attr) + return obj + + +def layer_decay_param_modifier( + scheduler_cfgs: List[List[Dict]], + model, + layer_decay_value: float, + layer_decay_min: Optional[float] = None, + apply_to: Optional[str] = None, + overrides: List[Dict] = (), +) -> List[List[Dict]]: + """ + Args + - scheduler_cfgs: a list of omegaconf.ListConfigs. + Each element in the list is a omegaconfg.DictConfig with the following structure + { + "scheduler": + "option": possible options are "lr", "weight_decay" etc. + "parameter_names": Set of str indicating param names that this scheduler applies to + } + - model: a model that implements a method `get_layer_id` that maps layer_name to an integer and + and a method get_num_layers. + Alternatively, use apply_to argument to select a specific component of the model. + - layer_decay_value: float + - layer_decay_min: min val for layer decay + - apply_to: optional arg to select which component of the model to apply the the layer decay modifier to + - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value". + Returns + - scheduler_configs: same structure as the input, elements can be modified + """ + model = rgetattr(model, apply_to) + num_layers = model.get_num_layers() + 1 + layer_decays = [ + layer_decay_value ** (num_layers - i) for i in range(num_layers + 1) + ] + if layer_decay_min is not None: + layer_decays = [max(val, layer_decay_min) for val in layer_decays] + final_scheduler_cfgs = [] + # scheduler_cfgs is a list of lists + for scheduler_cfg_group in scheduler_cfgs: + curr_cfg_group = [] + # scheduler_cfg_group is a list of dictionaries + for scheduler_cfg in scheduler_cfg_group: + if scheduler_cfg["option"] != "lr": + curr_cfg_group.append(scheduler_cfg) + continue + # Need sorted so that the list of parameter names is deterministic and consistent + # across re-runs of this job. Else it was causing issues with loading the optimizer + # state during a job restart (D38591759) + parameter_names = sorted(scheduler_cfg["parameter_names"]) + + # Only want one cfg group per layer + layer_cfg_groups = {} + for param_name in parameter_names: + layer_id = num_layers + this_scale = layer_decays[layer_id] + if param_name.startswith(apply_to): + layer_id = model.get_layer_id(param_name) + this_scale = layer_decays[layer_id] + # Overrides + for override in overrides: + if fnmatch.fnmatchcase(param_name, override["pattern"]): + this_scale = float(override["value"]) + layer_id = override["pattern"] + break + + if layer_id not in layer_cfg_groups: + curr_param = { + "option": scheduler_cfg["option"], + "scheduler": ValueScaler( + scheduler_cfg["scheduler"], this_scale + ), + "parameter_names": {param_name}, + } + else: + curr_param = layer_cfg_groups[layer_id] + curr_param["parameter_names"].add(param_name) + layer_cfg_groups[layer_id] = curr_param + + for layer_cfg in layer_cfg_groups.values(): + curr_cfg_group.append(layer_cfg) + + final_scheduler_cfgs.append(curr_cfg_group) + return final_scheduler_cfgs diff --git a/third_party/sam2/training/scripts/sav_frame_extraction_submitit.py b/third_party/sam2/training/scripts/sav_frame_extraction_submitit.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5ed2fc77deecf87c8d823bb3fdcf3cb856fc94 --- /dev/null +++ b/third_party/sam2/training/scripts/sav_frame_extraction_submitit.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +import argparse +import os +from pathlib import Path + +import cv2 + +import numpy as np +import submitit +import tqdm + + +def get_args_parser(): + parser = argparse.ArgumentParser( + description="[SA-V Preprocessing] Extracting JPEG frames", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # ------------ + # DATA + # ------------ + data_parser = parser.add_argument_group( + title="SA-V dataset data root", + description="What data to load and how to process it.", + ) + data_parser.add_argument( + "--sav-vid-dir", + type=str, + required=True, + help=("Where to find the SAV videos"), + ) + data_parser.add_argument( + "--sav-frame-sample-rate", + type=int, + default=4, + help="Rate at which to sub-sample frames", + ) + + # ------------ + # LAUNCH + # ------------ + launch_parser = parser.add_argument_group( + title="Cluster launch settings", + description="Number of jobs and retry settings.", + ) + launch_parser.add_argument( + "--n-jobs", + type=int, + required=True, + help="Shard the run over this many jobs.", + ) + launch_parser.add_argument( + "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes." + ) + launch_parser.add_argument( + "--partition", type=str, required=True, help="Partition to launch on." + ) + launch_parser.add_argument( + "--account", type=str, required=True, help="Partition to launch on." + ) + launch_parser.add_argument("--qos", type=str, required=True, help="QOS.") + + # ------------ + # OUTPUT + # ------------ + output_parser = parser.add_argument_group( + title="Setting for results output", description="Where and how to save results." + ) + output_parser.add_argument( + "--output-dir", + type=str, + required=True, + help=("Where to dump the extracted jpeg frames"), + ) + output_parser.add_argument( + "--slurm-output-root-dir", + type=str, + required=True, + help=("Where to save slurm outputs"), + ) + return parser + + +def decode_video(video_path: str): + assert os.path.exists(video_path) + video = cv2.VideoCapture(video_path) + video_frames = [] + while video.isOpened(): + ret, frame = video.read() + if ret: + video_frames.append(frame) + else: + break + return video_frames + + +def extract_frames(video_path, sample_rate): + frames = decode_video(video_path) + return frames[::sample_rate] + + +def submitit_launch(video_paths, sample_rate, save_root): + for path in tqdm.tqdm(video_paths): + frames = extract_frames(path, sample_rate) + output_folder = os.path.join(save_root, Path(path).stem) + if not os.path.exists(output_folder): + os.makedirs(output_folder) + for fid, frame in enumerate(frames): + frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg") + cv2.imwrite(frame_path, frame) + print(f"Saved output to {save_root}") + + +if __name__ == "__main__": + parser = get_args_parser() + args = parser.parse_args() + + sav_vid_dir = args.sav_vid_dir + save_root = args.output_dir + sample_rate = args.sav_frame_sample_rate + + # List all SA-V videos + mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")]) + mp4_files = np.array(mp4_files) + chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)] + + print(f"Processing videos in: {sav_vid_dir}") + print(f"Processing {len(mp4_files)} files") + print(f"Beginning processing in {args.n_jobs} processes") + + # Submitit params + jobs_dir = os.path.join(args.slurm_output_root_dir, "%j") + cpus_per_task = 4 + executor = submitit.AutoExecutor(folder=jobs_dir) + executor.update_parameters( + timeout_min=args.timeout, + gpus_per_node=0, + tasks_per_node=1, + slurm_array_parallelism=args.n_jobs, + cpus_per_task=cpus_per_task, + slurm_partition=args.partition, + slurm_account=args.account, + slurm_qos=args.qos, + ) + executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"]) + + # Launch + jobs = [] + with executor.batch(): + for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)): + job = executor.submit( + submitit_launch, + video_paths=mp4_chunk, + sample_rate=sample_rate, + save_root=save_root, + ) + jobs.append(job) + + for j in jobs: + print(f"Slurm JobID: {j.job_id}") + print(f"Saving outputs to {save_root}") + print(f"Slurm outputs at {args.slurm_output_root_dir}") diff --git a/third_party/sam2/training/train.py b/third_party/sam2/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..db06123fcb1b2ba8ff5f462dbb7411d42a57c9a0 --- /dev/null +++ b/third_party/sam2/training/train.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import sys +import traceback +from argparse import ArgumentParser + +import submitit +import torch + +from hydra import compose, initialize_config_module +from hydra.utils import instantiate + +from iopath.common.file_io import g_pathmgr +from omegaconf import OmegaConf + +from training.utils.train_utils import makedir, register_omegaconf_resolvers + +os.environ["HYDRA_FULL_ERROR"] = "1" + + +def single_proc_run(local_rank, main_port, cfg, world_size): + """Single GPU process""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(main_port) + os.environ["RANK"] = str(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + try: + register_omegaconf_resolvers() + except Exception as e: + logging.info(e) + + trainer = instantiate(cfg.trainer, _recursive_=False) + trainer.run() + + +def single_node_runner(cfg, main_port: int): + assert cfg.launcher.num_nodes == 1 + num_proc = cfg.launcher.gpus_per_node + torch.multiprocessing.set_start_method( + "spawn" + ) # CUDA runtime does not support `fork` + if num_proc == 1: + # directly call single_proc so we can easily set breakpoints + # mp.spawn does not let us set breakpoints + single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc) + else: + mp_runner = torch.multiprocessing.start_processes + args = (main_port, cfg, num_proc) + # Note: using "fork" below, "spawn" causes time and error regressions. Using + # spawn changes the default multiprocessing context to spawn, which doesn't + # interact well with the dataloaders (likely due to the use of OpenCV). + mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn") + + +def format_exception(e: Exception, limit=20): + traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit)) + return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}" + + +class SubmititRunner(submitit.helpers.Checkpointable): + """A callable which is passed to submitit to launch the jobs.""" + + def __init__(self, port, cfg): + self.cfg = cfg + self.port = port + self.has_setup = False + + def run_trainer(self): + job_env = submitit.JobEnvironment() + # Need to add this again so the hydra.job.set_env PYTHONPATH + # is also set when launching jobs. + add_pythonpath_to_sys_path() + os.environ["MASTER_ADDR"] = job_env.hostnames[0] + os.environ["MASTER_PORT"] = str(self.port) + os.environ["RANK"] = str(job_env.global_rank) + os.environ["LOCAL_RANK"] = str(job_env.local_rank) + os.environ["WORLD_SIZE"] = str(job_env.num_tasks) + + register_omegaconf_resolvers() + cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False) + cfg_resolved = OmegaConf.create(cfg_resolved) + + trainer = instantiate(cfg_resolved.trainer, _recursive_=False) + trainer.run() + + def __call__(self): + job_env = submitit.JobEnvironment() + self.setup_job_info(job_env.job_id, job_env.global_rank) + try: + self.run_trainer() + except Exception as e: + # Log the exception. Then raise it again (as what SubmititRunner currently does). + message = format_exception(e) + logging.error(message) + raise e + + def setup_job_info(self, job_id, rank): + """Set up slurm job info""" + self.job_info = { + "job_id": job_id, + "rank": rank, + "cluster": self.cfg.get("cluster", None), + "experiment_log_dir": self.cfg.launcher.experiment_log_dir, + } + + self.has_setup = True + + +def add_pythonpath_to_sys_path(): + if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]: + return + sys.path = os.environ["PYTHONPATH"].split(":") + sys.path + + +def main(args) -> None: + cfg = compose(config_name=args.config) + if cfg.launcher.experiment_log_dir is None: + cfg.launcher.experiment_log_dir = os.path.join( + os.getcwd(), "sam2_logs", args.config + ) + print("###################### Train App Config ####################") + print(OmegaConf.to_yaml(cfg)) + print("############################################################") + + add_pythonpath_to_sys_path() + makedir(cfg.launcher.experiment_log_dir) + with g_pathmgr.open( + os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w" + ) as f: + f.write(OmegaConf.to_yaml(cfg)) + + cfg_resolved = OmegaConf.to_container(cfg, resolve=False) + cfg_resolved = OmegaConf.create(cfg_resolved) + + with g_pathmgr.open( + os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w" + ) as f: + f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True)) + + submitit_conf = cfg.get("submitit", None) + assert submitit_conf is not None, "Missing submitit config" + + submitit_dir = cfg.launcher.experiment_log_dir + submitit_dir = os.path.join(submitit_dir, "submitit_logs") + # Priotrize cmd line args + cfg.launcher.gpus_per_node = ( + args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node + ) + cfg.launcher.num_nodes = ( + args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes + ) + submitit_conf.use_cluster = ( + args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster + ) + if submitit_conf.use_cluster: + executor = submitit.AutoExecutor(folder=submitit_dir) + submitit_conf.partition = ( + args.partition + if args.partition is not None + else submitit_conf.get("partition", None) + ) + submitit_conf.account = ( + args.account + if args.account is not None + else submitit_conf.get("account", None) + ) + submitit_conf.qos = ( + args.qos if args.qos is not None else submitit_conf.get("qos", None) + ) + job_kwargs = { + "timeout_min": 60 * submitit_conf.timeout_hour, + "name": ( + submitit_conf.name if hasattr(submitit_conf, "name") else args.config + ), + "slurm_partition": submitit_conf.partition, + "gpus_per_node": cfg.launcher.gpus_per_node, + "tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU + "cpus_per_task": submitit_conf.cpus_per_task, + "nodes": cfg.launcher.num_nodes, + "slurm_additional_parameters": { + "exclude": " ".join(submitit_conf.get("exclude_nodes", [])), + }, + } + if "include_nodes" in submitit_conf: + assert ( + len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes + ), "Not enough nodes" + job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join( + submitit_conf["include_nodes"] + ) + if submitit_conf.account is not None: + job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account + if submitit_conf.qos is not None: + job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos + + if submitit_conf.get("mem_gb", None) is not None: + job_kwargs["mem_gb"] = submitit_conf.mem_gb + elif submitit_conf.get("mem", None) is not None: + job_kwargs["slurm_mem"] = submitit_conf.mem + + if submitit_conf.get("constraints", None) is not None: + job_kwargs["slurm_constraint"] = submitit_conf.constraints + + if submitit_conf.get("comment", None) is not None: + job_kwargs["slurm_comment"] = submitit_conf.comment + + # Supports only cpu-bind option within srun_args. New options can be added here + if submitit_conf.get("srun_args", None) is not None: + job_kwargs["slurm_srun_args"] = [] + if submitit_conf.srun_args.get("cpu_bind", None) is not None: + job_kwargs["slurm_srun_args"].extend( + ["--cpu-bind", submitit_conf.srun_args.cpu_bind] + ) + + print("###################### SLURM Config ####################") + print(job_kwargs) + print("##########################################") + executor.update_parameters(**job_kwargs) + + main_port = random.randint( + submitit_conf.port_range[0], submitit_conf.port_range[1] + ) + runner = SubmititRunner(main_port, cfg) + job = executor.submit(runner) + print(f"Submitit Job ID: {job.job_id}") + runner.setup_job_info(job.job_id, rank=0) + else: + cfg.launcher.num_nodes = 1 + main_port = random.randint( + submitit_conf.port_range[0], submitit_conf.port_range[1] + ) + single_node_runner(cfg, main_port) + + +if __name__ == "__main__": + + initialize_config_module("sam2", version_base="1.2") + parser = ArgumentParser() + parser.add_argument( + "-c", + "--config", + required=True, + type=str, + help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)", + ) + parser.add_argument( + "--use-cluster", + type=int, + default=None, + help="whether to launch on a cluster, 0: run locally, 1: run on a cluster", + ) + parser.add_argument("--partition", type=str, default=None, help="SLURM partition") + parser.add_argument("--account", type=str, default=None, help="SLURM account") + parser.add_argument("--qos", type=str, default=None, help="SLURM qos") + parser.add_argument( + "--num-gpus", type=int, default=None, help="number of GPUS per node" + ) + parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes") + args = parser.parse_args() + args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None + register_omegaconf_resolvers() + main(args) diff --git a/third_party/sam2/training/trainer.py b/third_party/sam2/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7c27b5145e2c03848331345ac246296accbc1d --- /dev/null +++ b/third_party/sam2/training/trainer.py @@ -0,0 +1,1113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import gc +import json +import logging +import math +import os +import time +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Optional + +import numpy as np + +import torch +import torch.distributed as dist +import torch.nn as nn +from hydra.utils import instantiate +from iopath.common.file_io import g_pathmgr + +from training.optimizer import construct_optimizer + +from training.utils.checkpoint_utils import ( + assert_skipped_parameters_are_frozen, + exclude_params_matching_unix_pattern, + load_state_dict_into_model, + with_check_parameter_frozen, +) +from training.utils.data_utils import BatchedVideoDatapoint +from training.utils.distributed import all_reduce_max, barrier, get_rank + +from training.utils.logger import Logger, setup_logging + +from training.utils.train_utils import ( + AverageMeter, + collect_dict_keys, + DurationMeter, + get_amp_type, + get_machine_local_and_dist_rank, + get_resume_checkpoint, + human_readable_time, + is_dist_avail_and_initialized, + log_env_variables, + makedir, + MemMeter, + Phase, + ProgressMeter, + set_seeds, + setup_distributed_backend, +) + + +CORE_LOSS_KEY = "core_loss" + + +def unwrap_ddp_if_wrapped(model): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + return model + + +@dataclass +class OptimAMPConf: + enabled: bool = False + amp_dtype: str = "float16" + + +@dataclass +class OptimConf: + optimizer: torch.optim.Optimizer = None + options: Optional[Dict[str, Any]] = None + param_group_modifiers: Optional[List] = None + amp: Optional[Dict[str, Any]] = None + gradient_clip: Any = None + gradient_logger: Any = None + + def __post_init__(self): + # amp + if not isinstance(self.amp, OptimAMPConf): + if self.amp is None: + self.amp = {} + assert isinstance(self.amp, Mapping) + self.amp = OptimAMPConf(**self.amp) + + +@dataclass +class DistributedConf: + backend: Optional[str] = None # inferred from accelerator type + comms_dtype: Optional[str] = None + find_unused_parameters: bool = False + timeout_mins: int = 30 + + +@dataclass +class CudaConf: + cudnn_deterministic: bool = False + cudnn_benchmark: bool = True + allow_tf32: bool = False + # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul + matmul_allow_tf32: Optional[bool] = None + # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn + cudnn_allow_tf32: Optional[bool] = None + + +@dataclass +class CheckpointConf: + save_dir: str + save_freq: int + save_list: List[int] = field(default_factory=list) + model_weight_initializer: Any = None + save_best_meters: List[str] = None + skip_saving_parameters: List[str] = field(default_factory=list) + initialize_after_preemption: Optional[bool] = None + # if not None, training will be resumed from this checkpoint + resume_from: Optional[str] = None + + def infer_missing(self): + if self.initialize_after_preemption is None: + with_skip_saving = len(self.skip_saving_parameters) > 0 + self.initialize_after_preemption = with_skip_saving + return self + + +@dataclass +class LoggingConf: + log_dir: str + log_freq: int # In iterations + tensorboard_writer: Any + log_level_primary: str = "INFO" + log_level_secondary: str = "ERROR" + log_scalar_frequency: int = 100 + log_visual_frequency: int = 100 + scalar_keys_to_log: Optional[Dict[str, Any]] = None + log_batch_stats: bool = False + + +class Trainer: + """ + Trainer supporting the DDP training strategies. + """ + + EPSILON = 1e-8 + + def __init__( + self, + *, # the order of these args can change at any time, so they are keyword-only + data: Dict[str, Any], + model: Dict[str, Any], + logging: Dict[str, Any], + checkpoint: Dict[str, Any], + max_epochs: int, + mode: str = "train", + accelerator: str = "cuda", + seed_value: int = 123, + val_epoch_freq: int = 1, + distributed: Dict[str, bool] = None, + cuda: Dict[str, bool] = None, + env_variables: Optional[Dict[str, Any]] = None, + optim: Optional[Dict[str, Any]] = None, + optim_overrides: Optional[List[Dict[str, Any]]] = None, + meters: Optional[Dict[str, Any]] = None, + loss: Optional[Dict[str, Any]] = None, + ): + + self._setup_env_variables(env_variables) + self._setup_timers() + + self.data_conf = data + self.model_conf = model + self.logging_conf = LoggingConf(**logging) + self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing() + self.max_epochs = max_epochs + self.mode = mode + self.val_epoch_freq = val_epoch_freq + self.optim_conf = OptimConf(**optim) if optim is not None else None + self.meters_conf = meters + self.loss_conf = loss + distributed = DistributedConf(**distributed or {}) + cuda = CudaConf(**cuda or {}) + self.where = 0.0 + + self._infer_distributed_backend_if_none(distributed, accelerator) + + self._setup_device(accelerator) + + self._setup_torch_dist_and_backend(cuda, distributed) + + makedir(self.logging_conf.log_dir) + setup_logging( + __name__, + output_dir=self.logging_conf.log_dir, + rank=self.rank, + log_level_primary=self.logging_conf.log_level_primary, + log_level_secondary=self.logging_conf.log_level_secondary, + ) + + set_seeds(seed_value, self.max_epochs, self.distributed_rank) + log_env_variables() + + assert ( + is_dist_avail_and_initialized() + ), "Torch distributed needs to be initialized before calling the trainer." + + self._setup_components() # Except Optimizer everything is setup here. + self._move_to_device() + self._construct_optimizers() + self._setup_dataloaders() + + self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f") + + if self.checkpoint_conf.resume_from is not None: + assert os.path.exists( + self.checkpoint_conf.resume_from + ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!" + dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt") + if self.distributed_rank == 0 and not os.path.exists(dst): + # Copy the "resume_from" checkpoint to the checkpoint folder + # if there is not a checkpoint to resume from already there + makedir(self.checkpoint_conf.save_dir) + g_pathmgr.copy(self.checkpoint_conf.resume_from, dst) + barrier() + + self.load_checkpoint() + self._setup_ddp_distributed_training(distributed, accelerator) + barrier() + + def _setup_timers(self): + """ + Initializes counters for elapsed time and eta. + """ + self.start_time = time.time() + self.ckpt_time_elapsed = 0 + self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0) + + def _get_meters(self, phase_filters=None): + if self.meters is None: + return {} + meters = {} + for phase, phase_meters in self.meters.items(): + if phase_filters is not None and phase not in phase_filters: + continue + for key, key_meters in phase_meters.items(): + if key_meters is None: + continue + for name, meter in key_meters.items(): + meters[f"{phase}_{key}/{name}"] = meter + return meters + + def _infer_distributed_backend_if_none(self, distributed_conf, accelerator): + if distributed_conf.backend is None: + distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo" + + def _setup_env_variables(self, env_variables_conf) -> None: + if env_variables_conf is not None: + for variable_name, value in env_variables_conf.items(): + os.environ[variable_name] = value + + def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None: + if torch.cuda.is_available(): + torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic + torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark + torch.backends.cuda.matmul.allow_tf32 = ( + cuda_conf.matmul_allow_tf32 + if cuda_conf.matmul_allow_tf32 is not None + else cuda_conf.allow_tf32 + ) + torch.backends.cudnn.allow_tf32 = ( + cuda_conf.cudnn_allow_tf32 + if cuda_conf.cudnn_allow_tf32 is not None + else cuda_conf.allow_tf32 + ) + + self.rank = setup_distributed_backend( + distributed_conf.backend, distributed_conf.timeout_mins + ) + + def _setup_device(self, accelerator): + self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank() + if accelerator == "cuda": + self.device = torch.device("cuda", self.local_rank) + torch.cuda.set_device(self.local_rank) + elif accelerator == "cpu": + self.device = torch.device("cpu") + else: + raise ValueError(f"Unsupported accelerator: {accelerator}") + + def _setup_ddp_distributed_training(self, distributed_conf, accelerator): + + assert isinstance(self.model, torch.nn.Module) + + self.model = nn.parallel.DistributedDataParallel( + self.model, + device_ids=[self.local_rank] if accelerator == "cuda" else [], + find_unused_parameters=distributed_conf.find_unused_parameters, + ) + if distributed_conf.comms_dtype is not None: # noqa + from torch.distributed.algorithms import ddp_comm_hooks + + amp_type = get_amp_type(distributed_conf.comms_dtype) + if amp_type == torch.bfloat16: + hook = ddp_comm_hooks.default_hooks.bf16_compress_hook + logging.info("Enabling bfloat16 grad communication") + else: + hook = ddp_comm_hooks.default_hooks.fp16_compress_hook + logging.info("Enabling fp16 grad communication") + process_group = None + self.model.register_comm_hook(process_group, hook) + + def _move_to_device(self): + logging.info( + f"Moving components to device {self.device} and local rank {self.local_rank}." + ) + + self.model.to(self.device) + + logging.info( + f"Done moving components to device {self.device} and local rank {self.local_rank}." + ) + + def save_checkpoint(self, epoch, checkpoint_names=None): + checkpoint_folder = self.checkpoint_conf.save_dir + makedir(checkpoint_folder) + if checkpoint_names is None: + checkpoint_names = ["checkpoint"] + if ( + self.checkpoint_conf.save_freq > 0 + and (int(epoch) % self.checkpoint_conf.save_freq == 0) + ) or int(epoch) in self.checkpoint_conf.save_list: + checkpoint_names.append(f"checkpoint_{int(epoch)}") + + checkpoint_paths = [] + for ckpt_name in checkpoint_names: + checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt")) + + state_dict = unwrap_ddp_if_wrapped(self.model).state_dict() + state_dict = exclude_params_matching_unix_pattern( + patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict + ) + + checkpoint = { + "model": state_dict, + "optimizer": self.optim.optimizer.state_dict(), + "epoch": epoch, + "loss": self.loss.state_dict(), + "steps": self.steps, + "time_elapsed": self.time_elapsed_meter.val, + "best_meter_values": self.best_meter_values, + } + if self.optim_conf.amp.enabled: + checkpoint["scaler"] = self.scaler.state_dict() + + # DDP checkpoints are only saved on rank 0 (all workers are identical) + if self.distributed_rank != 0: + return + + for checkpoint_path in checkpoint_paths: + self._save_checkpoint(checkpoint, checkpoint_path) + + def _save_checkpoint(self, checkpoint, checkpoint_path): + """ + Save a checkpoint while guarding against the job being killed in the middle + of checkpoint saving (which corrupts the checkpoint file and ruins the + entire training since usually only the last checkpoint is kept per run). + + We first save the new checkpoint to a temp file (with a '.tmp' suffix), and + and move it to overwrite the old checkpoint_path. + """ + checkpoint_path_tmp = f"{checkpoint_path}.tmp" + with g_pathmgr.open(checkpoint_path_tmp, "wb") as f: + torch.save(checkpoint, f) + # after torch.save is completed, replace the old checkpoint with the new one + if g_pathmgr.exists(checkpoint_path): + # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails) + g_pathmgr.rm(checkpoint_path) + success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path) + assert success + + def load_checkpoint(self): + ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir) + if ckpt_path is None: + self._init_model_state() + else: + if self.checkpoint_conf.initialize_after_preemption: + self._call_model_initializer() + self._load_resuming_checkpoint(ckpt_path) + + def _init_model_state(self): + # Checking that parameters that won't be saved are indeed frozen + # We do this check here before even saving the model to catch errors + # are early as possible and not at the end of the first epoch + assert_skipped_parameters_are_frozen( + patterns=self.checkpoint_conf.skip_saving_parameters, + model=self.model, + ) + + # Checking that parameters that won't be saved are initialized from + # within the model definition, unless `initialize_after_preemption` + # is explicitly set to `True`. If not, this is a bug, and after + # preemption, the `skip_saving_parameters` will have random values + allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption + with with_check_parameter_frozen( + patterns=self.checkpoint_conf.skip_saving_parameters, + model=self.model, + disabled=allow_init_skip_parameters, + ): + self._call_model_initializer() + + def _call_model_initializer(self): + model_weight_initializer = instantiate( + self.checkpoint_conf.model_weight_initializer + ) + if model_weight_initializer is not None: + logging.info( + f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}" + ) + self.model = model_weight_initializer(model=self.model) + + def _load_resuming_checkpoint(self, ckpt_path: str): + logging.info(f"Resuming training from {ckpt_path}") + + with g_pathmgr.open(ckpt_path, "rb") as f: + checkpoint = torch.load(f, map_location="cpu") + load_state_dict_into_model( + model=self.model, + state_dict=checkpoint["model"], + ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters, + ) + + self.optim.optimizer.load_state_dict(checkpoint["optimizer"]) + self.loss.load_state_dict(checkpoint["loss"], strict=True) + self.epoch = checkpoint["epoch"] + self.steps = checkpoint["steps"] + self.ckpt_time_elapsed = checkpoint.get("time_elapsed") + + if self.optim_conf.amp.enabled and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.best_meter_values = checkpoint.get("best_meter_values", {}) + + if "train_dataset" in checkpoint and self.train_dataset is not None: + self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"]) + + def is_intermediate_val_epoch(self, epoch): + return epoch % self.val_epoch_freq == 0 and epoch < self.max_epochs - 1 + + def _step( + self, + batch: BatchedVideoDatapoint, + model: nn.Module, + phase: str, + ): + + outputs = model(batch) + targets = batch.masks + batch_size = len(batch.img_batch) + + key = batch.dict_key # key for dataset + loss = self.loss[key](outputs, targets) + loss_str = f"Losses/{phase}_{key}_loss" + + loss_log_str = os.path.join("Step_Losses", loss_str) + + # loss contains multiple sub-components we wish to log + step_losses = {} + if isinstance(loss, dict): + step_losses.update( + {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()} + ) + loss = self._log_loss_detailed_and_return_core_loss( + loss, loss_log_str, self.steps[phase] + ) + + if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0: + self.logger.log( + loss_log_str, + loss, + self.steps[phase], + ) + + self.steps[phase] += 1 + + ret_tuple = {loss_str: loss}, batch_size, step_losses + + if phase in self.meters and key in self.meters[phase]: + meters_dict = self.meters[phase][key] + if meters_dict is not None: + for _, meter in meters_dict.items(): + meter.update( + find_stages=outputs, + find_metadatas=batch.metadata, + ) + + return ret_tuple + + def run(self): + assert self.mode in ["train", "train_only", "val"] + if self.mode == "train": + if self.epoch > 0: + logging.info(f"Resuming training from epoch: {self.epoch}") + # resuming from a checkpoint + if self.is_intermediate_val_epoch(self.epoch - 1): + logging.info("Running previous val epoch") + self.epoch -= 1 + self.run_val() + self.epoch += 1 + self.run_train() + self.run_val() + elif self.mode == "val": + self.run_val() + elif self.mode == "train_only": + self.run_train() + + def _setup_dataloaders(self): + self.train_dataset = None + self.val_dataset = None + + if self.mode in ["train", "val"]: + self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None)) + + if self.mode in ["train", "train_only"]: + self.train_dataset = instantiate(self.data_conf.train) + + def run_train(self): + + while self.epoch < self.max_epochs: + dataloader = self.train_dataset.get_loader(epoch=int(self.epoch)) + barrier() + outs = self.train_epoch(dataloader) + self.logger.log_dict(outs, self.epoch) # Logged only on rank 0 + + # log train to text file. + if self.distributed_rank == 0: + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "train_stats.json"), + "a", + ) as f: + f.write(json.dumps(outs) + "\n") + + # Save checkpoint before validating + self.save_checkpoint(self.epoch + 1) + + del dataloader + gc.collect() + + # Run val, not running on last epoch since will run after the + # loop anyway + if self.is_intermediate_val_epoch(self.epoch): + self.run_val() + + if self.distributed_rank == 0: + self.best_meter_values.update(self._get_trainer_state("train")) + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "best_stats.json"), + "a", + ) as f: + f.write(json.dumps(self.best_meter_values) + "\n") + + self.epoch += 1 + # epoch was incremented in the loop but the val step runs out of the loop + self.epoch -= 1 + + def run_val(self): + if not self.val_dataset: + return + + dataloader = self.val_dataset.get_loader(epoch=int(self.epoch)) + outs = self.val_epoch(dataloader, phase=Phase.VAL) + del dataloader + gc.collect() + self.logger.log_dict(outs, self.epoch) # Logged only on rank 0 + + if self.distributed_rank == 0: + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "val_stats.json"), + "a", + ) as f: + f.write(json.dumps(outs) + "\n") + + def val_epoch(self, val_loader, phase): + batch_time = AverageMeter("Batch Time", self.device, ":.2f") + data_time = AverageMeter("Data Time", self.device, ":.2f") + mem = MemMeter("Mem (GB)", self.device, ":.2f") + + iters_per_epoch = len(val_loader) + + curr_phases = [phase] + curr_models = [self.model] + + loss_names = [] + for p in curr_phases: + for key in self.loss.keys(): + loss_names.append(f"Losses/{p}_{key}_loss") + + loss_mts = OrderedDict( + [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names] + ) + extra_loss_mts = {} + + for model in curr_models: + model.eval() + if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"): + unwrap_ddp_if_wrapped(model).on_validation_epoch_start() + + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()], + self._get_meters(curr_phases), + prefix="Val Epoch: [{}]".format(self.epoch), + ) + + end = time.time() + + for data_iter, batch in enumerate(val_loader): + + # measure data loading time + data_time.update(time.time() - end) + + batch = batch.to(self.device, non_blocking=True) + + # compute output + with torch.no_grad(): + with torch.cuda.amp.autocast( + enabled=(self.optim_conf.amp.enabled if self.optim_conf else False), + dtype=( + get_amp_type(self.optim_conf.amp.amp_dtype) + if self.optim_conf + else None + ), + ): + for phase, model in zip(curr_phases, curr_models): + loss_dict, batch_size, extra_losses = self._step( + batch, + model, + phase, + ) + + assert len(loss_dict) == 1 + loss_key, loss = loss_dict.popitem() + + loss_mts[loss_key].update(loss.item(), batch_size) + + for k, v in extra_losses.items(): + if k not in extra_loss_mts: + extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e") + extra_loss_mts[k].update(v.item(), batch_size) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + + if torch.cuda.is_available(): + mem.update(reset_peak_usage=True) + + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + if data_iter % self.logging_conf.log_scalar_frequency == 0: + # Log progress meters. + for progress_meter in progress.meters: + self.logger.log( + os.path.join("Step_Stats", phase, progress_meter.name), + progress_meter.val, + self.steps[Phase.VAL], + ) + + if data_iter % 10 == 0: + dist.barrier() + + self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch + self._log_timers(phase) + for model in curr_models: + if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"): + unwrap_ddp_if_wrapped(model).on_validation_epoch_end() + + out_dict = self._log_meters_and_save_best_ckpts(curr_phases) + + for k, v in loss_mts.items(): + out_dict[k] = v.avg + for k, v in extra_loss_mts.items(): + out_dict[k] = v.avg + + for phase in curr_phases: + out_dict.update(self._get_trainer_state(phase)) + self._reset_meters(curr_phases) + logging.info(f"Meters: {out_dict}") + return out_dict + + def _get_trainer_state(self, phase): + return { + "Trainer/where": self.where, + "Trainer/epoch": self.epoch, + f"Trainer/steps_{phase}": self.steps[phase], + } + + def train_epoch(self, train_loader): + + # Init stat meters + batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f") + data_time_meter = AverageMeter("Data Time", self.device, ":.2f") + mem_meter = MemMeter("Mem (GB)", self.device, ":.2f") + data_times = [] + phase = Phase.TRAIN + + iters_per_epoch = len(train_loader) + + loss_names = [] + for batch_key in self.loss.keys(): + loss_names.append(f"Losses/{phase}_{batch_key}_loss") + + loss_mts = OrderedDict( + [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names] + ) + extra_loss_mts = {} + + progress = ProgressMeter( + iters_per_epoch, + [ + batch_time_meter, + data_time_meter, + mem_meter, + self.time_elapsed_meter, + *loss_mts.values(), + ], + self._get_meters([phase]), + prefix="Train Epoch: [{}]".format(self.epoch), + ) + + # Model training loop + self.model.train() + end = time.time() + + for data_iter, batch in enumerate(train_loader): + # measure data loading time + data_time_meter.update(time.time() - end) + data_times.append(data_time_meter.val) + batch = batch.to( + self.device, non_blocking=True + ) # move tensors in a tensorclass + + try: + self._run_step(batch, phase, loss_mts, extra_loss_mts) + + # compute gradient and do optim step + exact_epoch = self.epoch + float(data_iter) / iters_per_epoch + self.where = float(exact_epoch) / self.max_epochs + assert self.where <= 1 + self.EPSILON + if self.where < 1.0: + self.optim.step_schedulers( + self.where, step=int(exact_epoch * iters_per_epoch) + ) + else: + logging.warning( + f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]." + ) + + # Log schedulers + if data_iter % self.logging_conf.log_scalar_frequency == 0: + for j, param_group in enumerate(self.optim.optimizer.param_groups): + for option in self.optim.schedulers[j]: + optim_prefix = ( + "" + f"{j}_" + if len(self.optim.optimizer.param_groups) > 1 + else "" + ) + self.logger.log( + os.path.join("Optim", f"{optim_prefix}", option), + param_group[option], + self.steps[phase], + ) + + # Clipping gradients and detecting diverging gradients + if self.gradient_clipper is not None: + self.scaler.unscale_(self.optim.optimizer) + self.gradient_clipper(model=self.model) + + if self.gradient_logger is not None: + self.gradient_logger( + self.model, rank=self.distributed_rank, where=self.where + ) + + # Optimizer step: the scaler will make sure gradients are not + # applied if the gradients are infinite + self.scaler.step(self.optim.optimizer) + self.scaler.update() + + # measure elapsed time + batch_time_meter.update(time.time() - end) + end = time.time() + + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + + mem_meter.update(reset_peak_usage=True) + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + if data_iter % self.logging_conf.log_scalar_frequency == 0: + # Log progress meters. + for progress_meter in progress.meters: + self.logger.log( + os.path.join("Step_Stats", phase, progress_meter.name), + progress_meter.val, + self.steps[phase], + ) + + # Catching NaN/Inf errors in the loss + except FloatingPointError as e: + raise e + + self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch + self._log_timers(Phase.TRAIN) + self._log_sync_data_times(Phase.TRAIN, data_times) + + out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN]) + + for k, v in loss_mts.items(): + out_dict[k] = v.avg + for k, v in extra_loss_mts.items(): + out_dict[k] = v.avg + out_dict.update(self._get_trainer_state(phase)) + logging.info(f"Losses and meters: {out_dict}") + self._reset_meters([phase]) + return out_dict + + def _log_sync_data_times(self, phase, data_times): + data_times = all_reduce_max(torch.tensor(data_times)).tolist() + steps = range(self.steps[phase] - len(data_times), self.steps[phase]) + for step, data_time in zip(steps, data_times): + if step % self.logging_conf.log_scalar_frequency == 0: + self.logger.log( + os.path.join("Step_Stats", phase, "Data Time Synced"), + data_time, + step, + ) + + def _run_step( + self, + batch: BatchedVideoDatapoint, + phase: str, + loss_mts: Dict[str, AverageMeter], + extra_loss_mts: Dict[str, AverageMeter], + raise_on_error: bool = True, + ): + """ + Run the forward / backward + """ + + # it's important to set grads to None, especially with Adam since 0 + # grads will also update a model even if the step doesn't produce + # gradients + self.optim.zero_grad(set_to_none=True) + with torch.cuda.amp.autocast( + enabled=self.optim_conf.amp.enabled, + dtype=get_amp_type(self.optim_conf.amp.amp_dtype), + ): + loss_dict, batch_size, extra_losses = self._step( + batch, + self.model, + phase, + ) + + assert len(loss_dict) == 1 + loss_key, loss = loss_dict.popitem() + + if not math.isfinite(loss.item()): + error_msg = f"Loss is {loss.item()}, attempting to stop training" + logging.error(error_msg) + if raise_on_error: + raise FloatingPointError(error_msg) + else: + return + + self.scaler.scale(loss).backward() + loss_mts[loss_key].update(loss.item(), batch_size) + for extra_loss_key, extra_loss in extra_losses.items(): + if extra_loss_key not in extra_loss_mts: + extra_loss_mts[extra_loss_key] = AverageMeter( + extra_loss_key, self.device, ":.2e" + ) + extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size) + + def _log_meters_and_save_best_ckpts(self, phases: List[str]): + logging.info("Synchronizing meters") + out_dict = {} + checkpoint_save_keys = [] + for key, meter in self._get_meters(phases).items(): + meter_output = meter.compute_synced() + is_better_check = getattr(meter, "is_better", None) + + for meter_subkey, meter_value in meter_output.items(): + out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value + + if is_better_check is None: + continue + + tracked_meter_key = os.path.join(key, meter_subkey) + if tracked_meter_key not in self.best_meter_values or is_better_check( + meter_value, + self.best_meter_values[tracked_meter_key], + ): + self.best_meter_values[tracked_meter_key] = meter_value + + if ( + self.checkpoint_conf.save_best_meters is not None + and key in self.checkpoint_conf.save_best_meters + ): + checkpoint_save_keys.append(tracked_meter_key.replace("/", "_")) + + if len(checkpoint_save_keys) > 0: + self.save_checkpoint(self.epoch + 1, checkpoint_save_keys) + + return out_dict + + def _log_timers(self, phase): + time_remaining = 0 + epochs_remaining = self.max_epochs - self.epoch - 1 + val_epochs_remaining = sum( + n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs) + ) + + # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with + # the end epoch. + if (self.max_epochs - 1) % self.val_epoch_freq != 0: + val_epochs_remaining += 1 + + # Remove the current val run from estimate + if phase == Phase.VAL: + val_epochs_remaining -= 1 + + time_remaining += ( + epochs_remaining * self.est_epoch_time[Phase.TRAIN] + + val_epochs_remaining * self.est_epoch_time[Phase.VAL] + ) + + self.logger.log( + os.path.join("Step_Stats", phase, self.time_elapsed_meter.name), + self.time_elapsed_meter.val, + self.steps[phase], + ) + + logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}") + + def _reset_meters(self, phases: str) -> None: + for meter in self._get_meters(phases).values(): + meter.reset() + + def _check_val_key_match(self, val_keys, phase): + if val_keys is not None: + # Check if there are any duplicates + assert len(val_keys) == len( + set(val_keys) + ), f"Duplicate keys in val datasets, keys: {val_keys}" + + # Check that the keys match the meter keys + if self.meters_conf is not None and phase in self.meters_conf: + assert set(val_keys) == set(self.meters_conf[phase].keys()), ( + f"Keys in val datasets do not match the keys in meters." + f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}" + f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}" + ) + + if self.loss_conf is not None: + loss_keys = set(self.loss_conf.keys()) - set(["all"]) + assert all([k in loss_keys for k in val_keys]), ( + f"Keys in val datasets do not match the keys in losses." + f"\nMissing in losses: {set(val_keys) - loss_keys}" + f"\nMissing in val datasets: {loss_keys - set(val_keys)}" + ) + + def _setup_components(self): + + # Get the keys for all the val datasets, if any + val_phase = Phase.VAL + val_keys = None + if self.data_conf.get(val_phase, None) is not None: + val_keys = collect_dict_keys(self.data_conf[val_phase]) + # Additional checks on the sanity of the config for val datasets + self._check_val_key_match(val_keys, phase=val_phase) + + logging.info("Setting up components: Model, loss, optim, meters etc.") + self.epoch = 0 + self.steps = {Phase.TRAIN: 0, Phase.VAL: 0} + + self.logger = Logger(self.logging_conf) + + self.model = instantiate(self.model_conf, _convert_="all") + print_model_summary(self.model) + + self.loss = None + if self.loss_conf: + self.loss = { + key: el # wrap_base_loss(el) + for (key, el) in instantiate(self.loss_conf, _convert_="all").items() + } + self.loss = nn.ModuleDict(self.loss) + + self.meters = {} + self.best_meter_values = {} + if self.meters_conf: + self.meters = instantiate(self.meters_conf, _convert_="all") + + self.scaler = torch.amp.GradScaler( + self.device, + enabled=self.optim_conf.amp.enabled if self.optim_conf else False, + ) + + self.gradient_clipper = ( + instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None + ) + self.gradient_logger = ( + instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None + ) + + logging.info("Finished setting up components: Model, loss, optim, meters etc.") + + def _construct_optimizers(self): + self.optim = construct_optimizer( + self.model, + self.optim_conf.optimizer, + self.optim_conf.options, + self.optim_conf.param_group_modifiers, + ) + + def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step): + core_loss = loss.pop(CORE_LOSS_KEY) + if step % self.logging_conf.log_scalar_frequency == 0: + for k in loss: + log_str = os.path.join(loss_str, k) + self.logger.log(log_str, loss[k], step) + return core_loss + + +def print_model_summary(model: torch.nn.Module, log_dir: str = ""): + """ + Prints the model and the number of parameters in the model. + # Multiple packages provide this info in a nice table format + # However, they need us to provide an `input` (as they also write down the output sizes) + # Our models are complex, and a single input is restrictive. + # https://github.com/sksq96/pytorch-summary + # https://github.com/nmhkahn/torchsummaryX + """ + if get_rank() != 0: + return + param_kwargs = {} + trainable_parameters = sum( + p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad + ) + total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs)) + non_trainable_parameters = total_parameters - trainable_parameters + logging.info("==" * 10) + logging.info(f"Summary for model {type(model)}") + logging.info(f"Model is {model}") + logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}") + logging.info( + f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}" + ) + logging.info( + f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}" + ) + logging.info("==" * 10) + + if log_dir: + output_fpath = os.path.join(log_dir, "model.txt") + with g_pathmgr.open(output_fpath, "w") as f: + print(model, file=f) + + +PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] + + +def get_human_readable_count(number: int) -> str: + """ + Abbreviates an integer number with K, M, B, T for thousands, millions, + billions and trillions, respectively. + Examples: + >>> get_human_readable_count(123) + '123 ' + >>> get_human_readable_count(1234) # (one thousand) + '1.2 K' + >>> get_human_readable_count(2e6) # (two million) + '2.0 M' + >>> get_human_readable_count(3e9) # (three billion) + '3.0 B' + >>> get_human_readable_count(4e14) # (four hundred trillion) + '400 T' + >>> get_human_readable_count(5e15) # (more than trillion) + '5,000 T' + Args: + number: a positive integer number + Return: + A string formatted according to the pattern described above. + """ + assert number >= 0 + labels = PARAMETER_NUM_UNITS + num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) + num_groups = int(np.ceil(num_digits / 3)) + num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions + shift = -3 * (num_groups - 1) + number = number * (10**shift) + index = num_groups - 1 + if index < 1 or number >= 100: + return f"{int(number):,d} {labels[index]}" + else: + return f"{number:,.1f} {labels[index]}" diff --git a/third_party/sam2/training/utils/__init__.py b/third_party/sam2/training/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/training/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/training/utils/checkpoint_utils.py b/third_party/sam2/training/utils/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f76689f341dedc485c0c32d096fb5b2e8337bea9 --- /dev/null +++ b/third_party/sam2/training/utils/checkpoint_utils.py @@ -0,0 +1,361 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import fnmatch +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import numpy as np +import torch +import torch.nn as nn +from iopath.common.file_io import g_pathmgr +from torch.jit._script import RecursiveScriptModule + + +def unix_pattern_to_parameter_names( + constraints: List[str], all_parameter_names: Sequence[str] +) -> Union[None, Set[str]]: + """ + Go through the list of parameter names and select those that match + any of the provided constraints + """ + parameter_names = [] + for param_name in constraints: + matching_parameters = set(fnmatch.filter(all_parameter_names, param_name)) + assert ( + len(matching_parameters) > 0 + ), f"param_names {param_name} don't match any param in the given names." + parameter_names.append(matching_parameters) + return set.union(*parameter_names) + + +def filter_params_matching_unix_pattern( + patterns: List[str], state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Remove from the state dictionary the parameters matching the provided unix patterns + + Args: + patterns: the list of unix patterns to exclude + state_dict: the dictionary to filter + + Returns: + A new state dictionary + """ + if len(patterns) == 0: + return {} + + all_keys = list(state_dict.keys()) + included_keys = unix_pattern_to_parameter_names(patterns, all_keys) + return {k: state_dict[k] for k in included_keys} + + +def exclude_params_matching_unix_pattern( + patterns: List[str], state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Remove from the state dictionary the parameters matching the provided unix patterns + + Args: + patterns: the list of unix patterns to exclude + state_dict: the dictionary to filter + + Returns: + A new state dictionary + """ + if len(patterns) == 0: + return state_dict + + all_keys = list(state_dict.keys()) + excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys) + return {k: v for k, v in state_dict.items() if k not in excluded_keys} + + +def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]): + keys = [] + trace = [] + for k, v in state_dict.items(): + keys.append(k) + trace.append(v.sum().item()) + trace = np.array(trace)[np.argsort(keys)] + return trace + + +def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]): + """ + Verifies that all the parameters matching the provided patterns + are frozen - this acts as a safeguard when ignoring parameter + when saving checkpoints - if the parameters are in fact trainable + """ + if not patterns: + return + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + non_frozen_keys = { + n + for n, p in model.named_parameters() + if n in frozen_state_dict and p.requires_grad + } + if non_frozen_keys: + raise ValueError( + f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}" + ) + + +@contextlib.contextmanager +def with_check_parameter_frozen( + model: nn.Module, patterns: List[str], disabled: bool = True +): + """ + Context manager that inspects a model surrounding a piece of code + and verifies if the model has been updated by this piece of code + + The function will raise an exception if the model has been updated + on at least one of the parameter that matches one of the pattern + + Args: + model: the model that might have been updated + patterns: for the parameters we want to observe + allowed: + """ + if not patterns or disabled: + yield + return + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + summary_before = _get_state_dict_summary(frozen_state_dict) + + yield + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + summary_after = _get_state_dict_summary(frozen_state_dict) + + if not np.allclose(summary_before, summary_after, atol=1e-6): + raise ValueError( + f""" + The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`. + You can resolve this error by either initializing those parameters from within the model definition + or using the flag `trainer.checkpoint.initialize_after_preemption` to True. + """ + ) + + +class CkptExcludeKernel: + """ + Removes the keys from the given model state_dict that match the key_pattern. + + Args: + key_pattern: Patterns used to select the keys in the state_dict + that are eligible for this kernel. + """ + + def __init__(self, key_pattern: List[str]): + self.key_pattern = key_pattern + + def __call__(self, state_dict: Dict): + """ + Args: + state_dict: A dictionary representing the given checkpoint's state dict. + """ + if len(self.key_pattern) == 0: + return state_dict + exclude_keys = unix_pattern_to_parameter_names( + self.key_pattern, state_dict.keys() + ) + return {k: v for k, v in state_dict.items() if k not in exclude_keys} + + +def load_checkpoint( + path_list: List[str], + pick_recursive_keys: Optional[List[str]] = None, + map_location: str = "cpu", +) -> Any: + """ + Loads a checkpoint from the specified path. + + Args: + path_list: A list of paths which contain the checkpoint. Each element + is tried (in order) until a file that exists is found. That file is then + used to read the checkpoint. + pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None. + For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"] + map_location (str): a function, torch.device, string or a dict specifying how to + remap storage locations + + Returns: Model with the matchin pre-trained weights loaded. + """ + path_exists = False + for path in path_list: + if g_pathmgr.exists(path): + path_exists = True + break + + if not path_exists: + raise ValueError(f"No path exists in {path_list}") + + with g_pathmgr.open(path, "rb") as f: + checkpoint = torch.load(f, map_location=map_location) + + logging.info(f"Loaded checkpoint from {path}") + if pick_recursive_keys is not None: + for key in pick_recursive_keys: + checkpoint = checkpoint[key] + return checkpoint + + +def get_state_dict(checkpoint, ckpt_state_dict_keys): + if isinstance(checkpoint, RecursiveScriptModule): + # This is a torchscript JIT model + return checkpoint.state_dict() + pre_train_dict = checkpoint + for i, key in enumerate(ckpt_state_dict_keys): + if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or ( + isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict) + ): + key_str = ( + '["' + '"]["'.join(list(map(ckpt_state_dict_keys[:i], str))) + '"]' + ) + raise KeyError( + f"'{key}' not found in checkpoint{key_str} " + f"with keys: {pre_train_dict.keys()}" + ) + pre_train_dict = pre_train_dict[key] + return pre_train_dict + + +def load_checkpoint_and_apply_kernels( + checkpoint_path: str, + checkpoint_kernels: List[Callable] = None, + ckpt_state_dict_keys: Tuple[str] = ("state_dict",), + map_location: str = "cpu", +) -> nn.Module: + """ + Performs checkpoint loading with a variety of pre-processing kernel applied in + sequence. + + Args: + checkpoint_path (str): Path to the checkpoint. + checkpoint_kernels List(Callable): A list of checkpoint processing kernels + to apply in the specified order. Supported kernels include `CkptIncludeKernel`, + `CkptExcludeKernel`, etc. These kernels are applied in the + given order. + ckpt_state_dict_keys (str): Keys containing the model state dict. + map_location (str): a function, torch.device, string or a dict specifying how to + remap storage locations + + Returns: Model with the matchin pre-trained weights loaded. + """ + assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format( + checkpoint_path + ) + + # Load the checkpoint on CPU to avoid GPU mem spike. + with g_pathmgr.open(checkpoint_path, "rb") as f: + checkpoint = torch.load(f, map_location=map_location) + + pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys) + + # Not logging into info etc since it's a huge log + logging.debug( + "Loaded Checkpoint State Dict pre-kernel application: %s" + % str(", ".join(list(pre_train_dict.keys()))) + ) + # Apply kernels + if checkpoint_kernels is not None: + for f in checkpoint_kernels: + pre_train_dict = f(state_dict=pre_train_dict) + + logging.debug( + "Loaded Checkpoint State Dict Post-kernel application %s" + % str(", ".join(list(pre_train_dict.keys()))) + ) + + return pre_train_dict + + +def check_load_state_dict_errors( + missing_keys, + unexpected_keys, + strict: bool, + ignore_missing_keys: List[str] = None, + ignore_unexpected_keys: List[str] = None, +): + if ignore_missing_keys is not None and len(ignore_missing_keys) > 0: + ignored_keys = unix_pattern_to_parameter_names( + ignore_missing_keys, missing_keys + ) + missing_keys = [key for key in missing_keys if key not in ignored_keys] + + if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0: + ignored_unexpected_keys = unix_pattern_to_parameter_names( + ignore_unexpected_keys, unexpected_keys + ) + unexpected_keys = [ + key for key in unexpected_keys if key not in ignored_unexpected_keys + ] + + err = "State key mismatch." + if unexpected_keys: + err += f" Unexpected keys: {unexpected_keys}." + if missing_keys: + err += f" Missing keys: {missing_keys}." + + if unexpected_keys or missing_keys: + logging.warning(err) + if unexpected_keys or strict: + raise KeyError(err) + + +def load_state_dict_into_model( + state_dict: Dict, + model: nn.Module, + strict: bool = True, + ignore_missing_keys: List[str] = None, + ignore_unexpected_keys: List[str] = None, + checkpoint_kernels: List[Callable] = None, +): + """ + Loads a state dict into the given model. + + Args: + state_dict: A dictionary containing the model's + state dict, or a subset if strict is False + model: Model to load the checkpoint weights into + strict: raise if the state_dict has missing state keys + ignore_missing_keys: unix pattern of keys to ignore + """ + # Apply kernels + if checkpoint_kernels is not None: + for f in checkpoint_kernels: + state_dict = f(state_dict=state_dict) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + check_load_state_dict_errors( + missing_keys, + unexpected_keys, + strict=strict, + ignore_missing_keys=ignore_missing_keys, + ignore_unexpected_keys=ignore_unexpected_keys, + ) + return model diff --git a/third_party/sam2/training/utils/data_utils.py b/third_party/sam2/training/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd0115355c97a27c601a833985466e558063b91 --- /dev/null +++ b/third_party/sam2/training/utils/data_utils.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from PIL import Image as PILImage +from tensordict import tensorclass + + +@tensorclass +class BatchedVideoMetaData: + """ + This class represents metadata about a batch of videos. + Attributes: + unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) + frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. + """ + + unique_objects_identifier: torch.LongTensor + frame_orig_size: torch.LongTensor + + +@tensorclass +class BatchedVideoDatapoint: + """ + This class represents a batch of videos with associated annotations and metadata. + Attributes: + img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. + obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. + masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. + metadata: An instance of BatchedVideoMetaData containing metadata about the batch. + dict_key: A string key used to identify the batch. + """ + + img_batch: torch.FloatTensor + obj_to_frame_idx: torch.IntTensor + masks: torch.BoolTensor + metadata: BatchedVideoMetaData + + dict_key: str + + def pin_memory(self, device=None): + return self.apply(torch.Tensor.pin_memory, device=device) + + @property + def num_frames(self) -> int: + """ + Returns the number of frames per video. + """ + return self.batch_size[0] + + @property + def num_videos(self) -> int: + """ + Returns the number of videos in the batch. + """ + return self.img_batch.shape[1] + + @property + def flat_obj_to_img_idx(self) -> torch.IntTensor: + """ + Returns a flattened tensor containing the object to img index. + The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] + """ + frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) + flat_idx = video_idx * self.num_frames + frame_idx + return flat_idx + + @property + def flat_img_batch(self) -> torch.FloatTensor: + """ + Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] + """ + + return self.img_batch.transpose(0, 1).flatten(0, 1) + + +@dataclass +class Object: + # Id of the object in the media + object_id: int + # Index of the frame in the media (0 if single image) + frame_index: int + segment: Union[torch.Tensor, dict] # RLE dict or binary mask + + +@dataclass +class Frame: + data: Union[torch.Tensor, PILImage.Image] + objects: List[Object] + + +@dataclass +class VideoDatapoint: + """Refers to an image/video and all its annotations""" + + frames: List[Frame] + video_id: int + size: Tuple[int, int] + + +def collate_fn( + batch: List[VideoDatapoint], + dict_key, +) -> BatchedVideoDatapoint: + """ + Args: + batch: A list of VideoDatapoint instances. + dict_key (str): A string key used to identify the batch. + """ + img_batch = [] + for video in batch: + img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] + + img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) + T = img_batch.shape[0] + # Prepare data structures for sequential processing. Per-frame processing but batched across videos. + step_t_objects_identifier = [[] for _ in range(T)] + step_t_frame_orig_size = [[] for _ in range(T)] + + step_t_masks = [[] for _ in range(T)] + step_t_obj_to_frame_idx = [ + [] for _ in range(T) + ] # List to store frame indices for each time step + + for video_idx, video in enumerate(batch): + orig_video_id = video.video_id + orig_frame_size = video.size + for t, frame in enumerate(video.frames): + objects = frame.objects + for obj in objects: + orig_obj_id = obj.object_id + orig_frame_idx = obj.frame_index + step_t_obj_to_frame_idx[t].append( + torch.tensor([t, video_idx], dtype=torch.int) + ) + step_t_masks[t].append(obj.segment.to(torch.bool)) + step_t_objects_identifier[t].append( + torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) + ) + step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) + + obj_to_frame_idx = torch.stack( + [ + torch.stack(obj_to_frame_idx, dim=0) + for obj_to_frame_idx in step_t_obj_to_frame_idx + ], + dim=0, + ) + masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) + objects_identifier = torch.stack( + [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 + ) + frame_orig_size = torch.stack( + [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 + ) + return BatchedVideoDatapoint( + img_batch=img_batch, + obj_to_frame_idx=obj_to_frame_idx, + masks=masks, + metadata=BatchedVideoMetaData( + unique_objects_identifier=objects_identifier, + frame_orig_size=frame_orig_size, + ), + dict_key=dict_key, + batch_size=[T], + ) diff --git a/third_party/sam2/training/utils/distributed.py b/third_party/sam2/training/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..f614b40427f40350c4df9e695cd327cb4d6a96f6 --- /dev/null +++ b/third_party/sam2/training/utils/distributed.py @@ -0,0 +1,576 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import functools +import io +import logging +import os +import random +import tempfile +import time +from typing import Any, Callable, List, Tuple + +import torch +import torch.autograd as autograd +import torch.distributed as dist + + +# Default to GPU 0 +_cuda_device_index: int = 0 + +# Setting _cuda_device_index to -1 internally implies that we should use CPU +_CPU_DEVICE_INDEX = -1 +_PRIMARY_RANK = 0 + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + + if dist.get_backend() == "nccl": + # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes + # being much slower than others causing a timeout (which can happen in relation + # or LVIS class mAP evaluation). + timeout = 43200 + return dist.new_group( + backend="gloo", + timeout=datetime.timedelta(seconds=timeout), + ) + + return dist.group.WORLD + + +def is_main_process(): + """Return true if the current process is the main one""" + return get_rank() == 0 + + +def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors), similar to + `all_gather` above, but using filesystem instead of collective ops. + + If gather_to_rank_0_only is True, only rank 0 will load the gathered object list + (and other ranks will have an empty list). + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + print("gathering via files") + cpu_group = _get_global_gloo_group() + + # if unspecified, we will save to the current python file dir + if filesys_save_dir is not None: + save_dir = filesys_save_dir + elif "EXP_DIR" in os.environ: + save_dir = os.environ["EXP_DIR"] + else: + # try the same directory where the code is stored + save_dir = filesys_save_dir or os.path.dirname(__file__) + save_dir = os.path.join(save_dir, "all_gather_via_filesys") + if is_main_process(): + os.makedirs(save_dir, exist_ok=True) + + # use a timestamp and salt to distinguish different all_gather + timestamp = int(time.time()) if is_main_process() else 0 + salt = random.randint(0, 2**31 - 1) if is_main_process() else 0 + # broadcast the timestamp and salt across ranks + # (all-reduce will do the broadcasting since only rank 0 is non-zero) + timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long) + dist.all_reduce(timestamp_and_salt, group=cpu_group) + timestamp, salt = timestamp_and_salt.tolist() + + # save the data to a file on the disk + rank_save = get_rank() + save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl" + save_data_path = os.path.join(save_dir, save_data_filename) + assert not os.path.exists(save_data_path), f"{save_data_path} already exists" + torch.save(data, save_data_path) + dist.barrier(group=cpu_group) + + # read the data from the files + data_list = [] + if rank_save == 0 or not gather_to_rank_0_only: + for rank_load in range(world_size): + load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl" + load_data_path = os.path.join(save_dir, load_data_filename) + assert os.path.exists(load_data_path), f"cannot read {save_data_path}" + data_list.append(torch.load(load_data_path)) + dist.barrier(group=cpu_group) + + # delete the saved file + os.remove(save_data_path) + return data_list + + +def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + + world_size = get_world_size() + if world_size == 1: + return [data] + + if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1": + return all_gather_via_filesys( + data, filesys_save_dir, gather_to_rank_0_only=True + ) + + if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys: + return all_gather_via_filesys(data, filesys_save_dir) + + cpu_group = None + if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu: + cpu_group = _get_global_gloo_group() + + buffer = io.BytesIO() + torch.save(data, buffer) + data_view = buffer.getbuffer() + device = "cuda" if cpu_group is None else "cpu" + tensor = torch.ByteTensor(data_view).to(device) + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) + size_list = [ + torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size) + ] + if cpu_group is None: + dist.all_gather(size_list, local_size) + else: + print("gathering on cpu") + dist.all_gather(size_list, local_size, group=cpu_group) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + assert isinstance(local_size.item(), int) + local_size = int(local_size.item()) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) + if local_size != max_size: + padding = torch.empty( + size=(max_size - local_size,), dtype=torch.uint8, device=device + ) + tensor = torch.cat((tensor, padding), dim=0) + if cpu_group is None: + dist.all_gather(tensor_list, tensor) + else: + dist.all_gather(tensor_list, tensor, group=cpu_group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] + buffer = io.BytesIO(tensor.cpu().numpy()) + obj = torch.load(buffer) + data_list.append(obj) + + return data_list + + +def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This helper function converts to the correct + device and returns the tensor + original device. + """ + orig_device = "cpu" if not tensor.is_cuda else "gpu" + if ( + torch.distributed.is_available() + and torch.distributed.get_backend() == torch.distributed.Backend.NCCL + and not tensor.is_cuda + ): + tensor = tensor.cuda() + return (tensor, orig_device) + + +def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This converts the tensor back to original device. + """ + if tensor.is_cuda and orig_device == "cpu": + tensor = tensor.cpu() + return tensor + + +def is_distributed_training_run() -> bool: + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and (torch.distributed.get_world_size() > 1) + ) + + +def is_primary() -> bool: + """ + Returns True if this is rank 0 of a distributed training job OR if it is + a single trainer job. Otherwise False. + """ + return get_rank() == _PRIMARY_RANK + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing mean reduction + of tensor over all processes. + """ + return all_reduce_op( + tensor, + torch.distributed.ReduceOp.SUM, + lambda t: t / torch.distributed.get_world_size(), + ) + + +def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing sum + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM) + + +def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing min + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN) + + +def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing min + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX) + + +def all_reduce_op( + tensor: torch.Tensor, + op: torch.distributed.ReduceOp, + after_op_func: Callable[[torch.Tensor], torch.Tensor] = None, +) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + torch.distributed.all_reduce(tensor, op) + if after_op_func is not None: + tensor = after_op_func(tensor) + tensor = convert_to_normal_tensor(tensor, orig_device) + return tensor + + +def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Wrapper over torch.distributed.all_gather for performing + 'gather' of 'tensor' over all processes in both distributed / + non-distributed scenarios. + """ + if tensor.ndim == 0: + # 0 dim tensors cannot be gathered. so unsqueeze + tensor = tensor.unsqueeze(0) + + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + gathered_tensors = [ + torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(gathered_tensors, tensor) + gathered_tensors = [ + convert_to_normal_tensor(_tensor, orig_device) + for _tensor in gathered_tensors + ] + else: + gathered_tensors = [tensor] + + return gathered_tensors + + +def gather_from_all(tensor: torch.Tensor) -> torch.Tensor: + gathered_tensors = gather_tensors_from_all(tensor) + gathered_tensor = torch.cat(gathered_tensors, 0) + return gathered_tensor + + +def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """ + Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source + to all processes in both distributed / non-distributed scenarios. + """ + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + torch.distributed.broadcast(tensor, src) + tensor = convert_to_normal_tensor(tensor, orig_device) + return tensor + + +def barrier() -> None: + """ + Wrapper over torch.distributed.barrier, returns without waiting + if the distributed process group is not initialized instead of throwing error. + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return + torch.distributed.barrier() + + +def get_world_size() -> int: + """ + Simple wrapper for correctly getting worldsize in both distributed + / non-distributed settings + """ + return ( + torch.distributed.get_world_size() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 1 + ) + + +def get_rank() -> int: + """ + Simple wrapper for correctly getting rank in both distributed + / non-distributed settings + """ + return ( + torch.distributed.get_rank() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 0 + ) + + +def get_primary_rank() -> int: + return _PRIMARY_RANK + + +def set_cuda_device_index(idx: int) -> None: + global _cuda_device_index + _cuda_device_index = idx + torch.cuda.set_device(_cuda_device_index) + + +def set_cpu_device() -> None: + global _cuda_device_index + _cuda_device_index = _CPU_DEVICE_INDEX + + +def get_cuda_device_index() -> int: + return _cuda_device_index + + +def init_distributed_data_parallel_model( + model: torch.nn.Module, + broadcast_buffers: bool = False, + find_unused_parameters: bool = True, + bucket_cap_mb: int = 25, +) -> torch.nn.parallel.DistributedDataParallel: + global _cuda_device_index + + if _cuda_device_index == _CPU_DEVICE_INDEX: + # CPU-only model, don't specify device + return torch.nn.parallel.DistributedDataParallel( + model, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + bucket_cap_mb=bucket_cap_mb, + ) + else: + # GPU model + return torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[_cuda_device_index], + output_device=_cuda_device_index, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + bucket_cap_mb=bucket_cap_mb, + ) + + +def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any: + """Broadcast an object from a source to all workers. + + Args: + obj: Object to broadcast, must be serializable + src: Source rank for broadcast (default is primary) + use_disk: If enabled, removes redundant CPU memory copies by writing to + disk + """ + # Either broadcast from primary to the fleet (default), + # or use the src setting as the original rank + if get_rank() == src: + # Emit data + buffer = io.BytesIO() + torch.save(obj, buffer) + data_view = buffer.getbuffer() + length_tensor = torch.LongTensor([len(data_view)]) + length_tensor = broadcast(length_tensor, src=src) + data_tensor = torch.ByteTensor(data_view) + data_tensor = broadcast(data_tensor, src=src) + else: + # Fetch from the source + length_tensor = torch.LongTensor([0]) + length_tensor = broadcast(length_tensor, src=src) + data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8) + data_tensor = broadcast(data_tensor, src=src) + if use_disk: + with tempfile.TemporaryFile("r+b") as f: + f.write(data_tensor.numpy()) + # remove reference to the data tensor and hope that Python garbage + # collects it + del data_tensor + f.seek(0) + obj = torch.load(f) + else: + buffer = io.BytesIO(data_tensor.numpy()) + obj = torch.load(buffer) + return obj + + +def all_gather_tensor(tensor: torch.Tensor, world_size=None): + if world_size is None: + world_size = get_world_size() + # make contiguous because NCCL won't gather the tensor otherwise + assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!" + tensor, orig_device = convert_to_distributed_tensor(tensor) + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_all, tensor, async_op=False) # performance opt + tensor_all = [ + convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all + ] + return tensor_all + + +def all_gather_batch(tensors: List[torch.Tensor]): + """ + Performs all_gather operation on the provided tensors. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + for tensor in tensors: + tensor_all = all_gather_tensor(tensor, world_size) + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor + + +class GatherLayer(autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + dist.all_reduce(all_gradients) + return all_gradients[dist.get_rank()] + + +def all_gather_batch_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + + for tensor in tensors: + tensor_all = GatherLayer.apply(tensor) + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor + + +def unwrap_ddp_if_wrapped(model): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + return model + + +def create_new_process_group(group_size): + """ + Creates process groups of a gives `group_size` and returns + process group that current GPU participates in. + + `group_size` must divide the total number of GPUs (world_size). + + Modified from + https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60 + + Args: + group_size (int): number of GPU's to collaborate for sync bn + """ + + assert group_size > 0 + + world_size = torch.distributed.get_world_size() + if world_size <= 8: + if group_size > world_size: + logging.warning( + f"Requested group size [{group_size}] > world size [{world_size}]. " + "Assuming local debug run and capping it to world size." + ) + group_size = world_size + assert world_size >= group_size + assert world_size % group_size == 0 + + group = None + for group_num in range(world_size // group_size): + group_ids = range(group_num * group_size, (group_num + 1) * group_size) + cur_group = torch.distributed.new_group(ranks=group_ids) + if torch.distributed.get_rank() // group_size == group_num: + group = cur_group + # can not drop out and return here, every process must go through creation of all subgroups + + assert group is not None + return group + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True diff --git a/third_party/sam2/training/utils/logger.py b/third_party/sam2/training/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b4ef0ebe359063e1ca2c3a46cb8fcc76d067c2 --- /dev/null +++ b/third_party/sam2/training/utils/logger.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py +import atexit +import functools +import logging +import sys +import uuid +from typing import Any, Dict, Optional, Union + +from hydra.utils import instantiate + +from iopath.common.file_io import g_pathmgr +from numpy import ndarray +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter + +from training.utils.train_utils import get_machine_local_and_dist_rank, makedir + +Scalar = Union[Tensor, ndarray, int, float] + + +def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any): + makedir(log_dir) + summary_writer_method = SummaryWriter + return TensorBoardLogger( + path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs + ) + + +class TensorBoardWriterWrapper: + """ + A wrapper around a SummaryWriter object. + """ + + def __init__( + self, + path: str, + *args: Any, + filename_suffix: str = None, + summary_writer_method: Any = SummaryWriter, + **kwargs: Any, + ) -> None: + """Create a new TensorBoard logger. + On construction, the logger creates a new events file that logs + will be written to. If the environment variable `RANK` is defined, + logger will only log if RANK = 0. + + NOTE: If using the logger with distributed training: + - This logger can call collective operations + - Logs will be written on rank 0 only + - Logger must be constructed synchronously *after* initializing distributed process group. + + Args: + path (str): path to write logs to + *args, **kwargs: Extra arguments to pass to SummaryWriter + """ + self._writer: Optional[SummaryWriter] = None + _, self._rank = get_machine_local_and_dist_rank() + self._path: str = path + if self._rank == 0: + logging.info( + f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}" + ) + self._writer = summary_writer_method( + log_dir=path, + *args, + filename_suffix=filename_suffix or str(uuid.uuid4()), + **kwargs, + ) + else: + logging.debug( + f"Not logging meters on this host because env RANK: {self._rank} != 0" + ) + atexit.register(self.close) + + @property + def writer(self) -> Optional[SummaryWriter]: + return self._writer + + @property + def path(self) -> str: + return self._path + + def flush(self) -> None: + """Writes pending logs to disk.""" + + if not self._writer: + return + + self._writer.flush() + + def close(self) -> None: + """Close writer, flushing pending logs to disk. + Logs cannot be written after `close` is called. + """ + + if not self._writer: + return + + self._writer.close() + self._writer = None + + +class TensorBoardLogger(TensorBoardWriterWrapper): + """ + A simple logger for TensorBoard. + """ + + def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: + """Add multiple scalar values to TensorBoard. + + Args: + payload (dict): dictionary of tag name and scalar value + step (int, Optional): step value to record + """ + if not self._writer: + return + for k, v in payload.items(): + self.log(k, v, step) + + def log(self, name: str, data: Scalar, step: int) -> None: + """Add scalar data to TensorBoard. + + Args: + name (string): tag name used to group scalars + data (float/int/Tensor): scalar data to log + step (int, optional): step value to record + """ + if not self._writer: + return + self._writer.add_scalar(name, data, global_step=step, new_style=True) + + def log_hparams( + self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] + ) -> None: + """Add hyperparameter data to TensorBoard. + + Args: + hparams (dict): dictionary of hyperparameter names and corresponding values + meters (dict): dictionary of name of meter and corersponding values + """ + if not self._writer: + return + self._writer.add_hparams(hparams, meters) + + +class Logger: + """ + A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger. + """ + + def __init__(self, logging_conf): + # allow turning off TensorBoard with "should_log: false" in config + tb_config = logging_conf.tensorboard_writer + tb_should_log = tb_config and tb_config.pop("should_log", True) + self.tb_logger = instantiate(tb_config) if tb_should_log else None + + def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: + if self.tb_logger: + self.tb_logger.log_dict(payload, step) + + def log(self, name: str, data: Scalar, step: int) -> None: + if self.tb_logger: + self.tb_logger.log(name, data, step) + + def log_hparams( + self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] + ) -> None: + if self.tb_logger: + self.tb_logger.log_hparams(hparams, meters) + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + # we tune the buffering value so that the logs are updated + # frequently. + log_buffer_kb = 10 * 1024 # 10KB + io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb) + atexit.register(io.close) + return io + + +def setup_logging( + name, + output_dir=None, + rank=0, + log_level_primary="INFO", + log_level_secondary="ERROR", +): + """ + Setup various logging streams: stdout and file handlers. + For file handlers, we only setup for the master gpu. + """ + # get the filename if we want to log to the file as well + log_filename = None + if output_dir: + makedir(output_dir) + if rank == 0: + log_filename = f"{output_dir}/log.txt" + + logger = logging.getLogger(name) + logger.setLevel(log_level_primary) + + # create formatter + FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s" + formatter = logging.Formatter(FORMAT) + + # Cleanup any existing handlers + for h in logger.handlers: + logger.removeHandler(h) + logger.root.handlers = [] + + # setup the console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + if rank == 0: + console_handler.setLevel(log_level_primary) + else: + console_handler.setLevel(log_level_secondary) + + # we log to file as well if user wants + if log_filename and rank == 0: + file_handler = logging.StreamHandler(_cached_log_stream(log_filename)) + file_handler.setLevel(log_level_primary) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + logging.root = logger + + +def shutdown_logging(): + """ + After training is done, we ensure to shut down all the logger streams. + """ + logging.info("Shutting down loggers...") + handlers = logging.root.handlers + for handler in handlers: + handler.close() diff --git a/third_party/sam2/training/utils/train_utils.py b/third_party/sam2/training/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..91d5577d5f50c81624737d221dc572ac3c4cee56 --- /dev/null +++ b/third_party/sam2/training/utils/train_utils.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import os +import random +import re +from datetime import timedelta +from typing import Optional + +import hydra + +import numpy as np +import omegaconf +import torch +import torch.distributed as dist +from iopath.common.file_io import g_pathmgr +from omegaconf import OmegaConf + + +def multiply_all(*args): + return np.prod(np.array(args)).item() + + +def collect_dict_keys(config): + """This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined""" + val_keys = [] + # If the this config points to the collate function, then it has a key + if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]): + val_keys.append(config["dict_key"]) + else: + # Recursively proceed + for v in config.values(): + if isinstance(v, type(config)): + val_keys.extend(collect_dict_keys(v)) + elif isinstance(v, omegaconf.listconfig.ListConfig): + for item in v: + if isinstance(item, type(config)): + val_keys.extend(collect_dict_keys(item)) + return val_keys + + +class Phase: + TRAIN = "train" + VAL = "val" + + +def register_omegaconf_resolvers(): + OmegaConf.register_new_resolver("get_method", hydra.utils.get_method) + OmegaConf.register_new_resolver("get_class", hydra.utils.get_class) + OmegaConf.register_new_resolver("add", lambda x, y: x + y) + OmegaConf.register_new_resolver("times", multiply_all) + OmegaConf.register_new_resolver("divide", lambda x, y: x / y) + OmegaConf.register_new_resolver("pow", lambda x, y: x**y) + OmegaConf.register_new_resolver("subtract", lambda x, y: x - y) + OmegaConf.register_new_resolver("range", lambda x: list(range(x))) + OmegaConf.register_new_resolver("int", lambda x: int(x)) + OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x))) + OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x)) + + +def setup_distributed_backend(backend, timeout_mins): + """ + Initialize torch.distributed and set the CUDA device. + Expects environment variables to be set as per + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + along with the environ variable "LOCAL_RANK" which is used to set the CUDA device. + """ + # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins + # of waiting + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" + logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins") + dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins)) + return dist.get_rank() + + +def get_machine_local_and_dist_rank(): + """ + Get the distributed and local rank of the current gpu. + """ + local_rank = int(os.environ.get("LOCAL_RANK", None)) + distributed_rank = int(os.environ.get("RANK", None)) + assert ( + local_rank is not None and distributed_rank is not None + ), "Please the set the RANK and LOCAL_RANK environment variables." + return local_rank, distributed_rank + + +def print_cfg(cfg): + """ + Supports printing both Hydra DictConfig and also the AttrDict config + """ + logging.info("Training with config:") + logging.info(OmegaConf.to_yaml(cfg)) + + +def set_seeds(seed_value, max_epochs, dist_rank): + """ + Set the python random, numpy and torch seed for each gpu. Also set the CUDA + seeds if the CUDA is available. This ensures deterministic nature of the training. + """ + # Since in the pytorch sampler, we increment the seed by 1 for every epoch. + seed_value = (seed_value + dist_rank) * max_epochs + logging.info(f"MACHINE SEED: {seed_value}") + random.seed(seed_value) + np.random.seed(seed_value) + torch.manual_seed(seed_value) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed_value) + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_amp_type(amp_type: Optional[str] = None): + if amp_type is None: + return None + assert amp_type in ["bfloat16", "float16"], "Invalid Amp type." + if amp_type == "bfloat16": + return torch.bfloat16 + else: + return torch.float16 + + +def log_env_variables(): + env_keys = sorted(list(os.environ.keys())) + st = "" + for k in env_keys: + v = os.environ[k] + st += f"{k}={v}\n" + logging.info("Logging ENV_VARIABLES") + logging.info(st) + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, device, fmt=":f"): + self.name = name + self.fmt = fmt + self.device = device + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self._allow_updates = True + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class MemMeter: + """Computes and stores the current, avg, and max of peak Mem usage per iteration""" + + def __init__(self, name, device, fmt=":f"): + self.name = name + self.fmt = fmt + self.device = device + self.reset() + + def reset(self): + self.val = 0 # Per iteration max usage + self.avg = 0 # Avg per iteration max usage + self.peak = 0 # Peak usage for lifetime of program + self.sum = 0 + self.count = 0 + self._allow_updates = True + + def update(self, n=1, reset_peak_usage=True): + self.val = torch.cuda.max_memory_allocated() // 1e9 + self.sum += self.val * n + self.count += n + self.avg = self.sum / self.count + self.peak = max(self.peak, self.val) + if reset_peak_usage: + torch.cuda.reset_peak_memory_stats() + + def __str__(self): + fmtstr = ( + "{name}: {val" + + self.fmt + + "} ({avg" + + self.fmt + + "}/{peak" + + self.fmt + + "})" + ) + return fmtstr.format(**self.__dict__) + + +def human_readable_time(time_seconds): + time = int(time_seconds) + minutes, seconds = divmod(time, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + return f"{days:02}d {hours:02}h {minutes:02}m" + + +class DurationMeter: + def __init__(self, name, device, fmt=":f"): + self.name = name + self.device = device + self.fmt = fmt + self.val = 0 + + def reset(self): + self.val = 0 + + def update(self, val): + self.val = val + + def add(self, val): + self.val += val + + def __str__(self): + return f"{self.name}: {human_readable_time(self.val)}" + + +class ProgressMeter: + def __init__(self, num_batches, meters, real_meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.real_meters = real_meters + self.prefix = prefix + + def display(self, batch, enable_print=False): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + entries += [ + " | ".join( + [ + f"{os.path.join(name, subname)}: {val:.4f}" + for subname, val in meter.compute().items() + ] + ) + for name, meter in self.real_meters.items() + ] + logging.info(" | ".join(entries)) + if enable_print: + print(" | ".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def get_resume_checkpoint(checkpoint_save_dir): + if not g_pathmgr.isdir(checkpoint_save_dir): + return None + ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt") + if not g_pathmgr.isfile(ckpt_file): + return None + + return ckpt_file